未验证 提交 269bd1fe 编写于 作者: P piotrekobi 提交者: GitHub

[PHI] Move oneDNN helper classes to new location (#45626)

* gaussian random

* mkldnn to onednn renaming

* fix merge conflicts

* remove fluid code

* onednn renaming

* Move classes from mkldnn_reuse.h to onednn_reuse.h

* Move more functions from mkldnn_helper.h to onednn_helpper.h

* Change MKLDNN to OneDNN in VLOG message
Co-authored-by: NSilv3S <slawomir.siwek@intel.com>
上级 4e4f4586
...@@ -212,10 +212,10 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name, ...@@ -212,10 +212,10 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
out_layout = framework::DataLayout::kNCHW; out_layout = framework::DataLayout::kNCHW;
} }
if (in_layout == framework::DataLayout::MKLDNN && if (in_layout == framework::DataLayout::ONEDNN &&
out_layout != framework::DataLayout::MKLDNN) { out_layout != framework::DataLayout::ONEDNN) {
auto target_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); auto target_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
VLOG(4) << "TransDataLayoutFromMKLDNN: " << in_layout << "->" VLOG(4) << "TransDataLayoutFromOneDNN: " << in_layout << "->"
<< target_layout; << target_layout;
if (out_layout == DataLayout::kNCHW && if (out_layout == DataLayout::kNCHW &&
......
...@@ -75,7 +75,7 @@ TEST(PhiUtils, TransOpKernelTypeToPhiKernelKey) { ...@@ -75,7 +75,7 @@ TEST(PhiUtils, TransOpKernelTypeToPhiKernelKey) {
auto kernel_key_mkldnn = auto kernel_key_mkldnn =
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn); paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn);
ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32); ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::MKLDNN); ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::ONEDNN);
ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::ONEDNN); ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::ONEDNN);
#endif #endif
......
...@@ -39,537 +39,24 @@ template <typename T, ...@@ -39,537 +39,24 @@ template <typename T,
typename TForward, typename TForward,
typename TBackward = mkldnn_dummy_primitive, typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive> typename TBackward_params = mkldnn_dummy_primitive>
using MKLDNNHandlerNoCachingT = phi::funcs:: using MKLDNNHandlerT =
MKLDNNHandlerNoCachingT<T, TForward, TBackward, TBackward_params>; phi::funcs::OneDNNHandlerT<T, TForward, TBackward, TBackward_params>;
template <typename T, template <typename T,
typename TForward, typename TForward,
typename TBackward = mkldnn_dummy_primitive, typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive> typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerT { using MKLDNNHandlerNoCachingT = phi::funcs::
public: OneDNNHandlerNoCachingT<T, TForward, TBackward, TBackward_params>;
MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx,
dnnl::engine engine,
platform::Place cpu_place,
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
place_(cpu_place),
key_common_(base_key),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@fwd_p";
auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) {
forward_p = std::make_shared<TForward>(*fwd_pd_);
dev_ctx_.SetBlob(key_p, forward_p);
}
return forward_p;
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
const std::string key_p = key_ + "@bwd_p";
auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
backward_p = std::make_shared<TBackward>(*bwd_pd_);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
const std::string key_p = key_ + "@bwd_w_p";
auto backward_p =
std::static_pointer_cast<TBackward_params>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable("BWD_PD should be set when "
"getting BWD prim witk key: %s .",
key_p));
backward_p = std::make_shared<TBackward_params>(*bwd_w_pd_);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src_mem_p");
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output) {
T_out* ptr =
output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(
fwd_pd_->dst_desc(), ptr, "@dst_mem_p");
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), "@dstt_mem_p");
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(
const framework::Tensor* output) {
const T_out* output_data = output->data<T_out>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
to_void_cast<T_out>(output_data),
"@bwd-dst_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(
bwd_pd_->diff_dst_desc(), to_void_cast<T>(ptr), "@diff_dst_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr =
diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(
bwd_pd_->diff_src_desc(), ptr, "@diff_src_mem_p");
}
// Buffer of given Tensor is used for oneDNN computation
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
framework::Tensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(
bwd_w_pd_->diff_weights_desc(), ptr, "@diff_wei_mem_p");
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
"@diff_wei_mem_p");
}
protected:
bool isCached() {
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
return (fwd_pd_ != nullptr);
}
bool isBwdCached() {
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
return false;
} else {
if (std::is_same<TBackward_params, mkldnn_dummy_primitive>::value ==
false) {
const std::string key_bw_w_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_bw_w_pd));
}
// When BWD is cached then still we need to Get FWD PD
const std::string key_fpd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fpd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable(
"Error: FWD PD should be set when BWD PD is cached."));
return true;
}
}
// If your primitive descriptor requires attributes, pass them as a
// first argument and paramters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
// This is used when we can recreate FWD PD in BWD so
// we do not need to pass FWD to BWD
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_pd, fwd_pd_);
}
}
// Using sfinae to specialise variadic function. Workaround for not having
// if constexpr in C++ 11.
template <class First, class... Args>
typename std::enable_if<std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
fwd_desc, first, engine_);
}
template <class First, class... Args>
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<First>(first),
std::forward<Args>(args)...);
fwd_pd_ =
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_pd_);
}
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_w_pd_ == nullptr) {
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_w_pd_);
}
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
const std::string& suffix) {
return std::static_pointer_cast<dnnl::memory>(
dev_ctx_.GetBlob(key_ + suffix));
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md, void* ptr, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<dnnl::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<dnnl::memory>(md, engine_);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
const std::shared_ptr<dnnl::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
}
template <typename F = T>
std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
const dnnl::memory::desc& user_md,
const dnnl::memory::desc& target_md,
void* ptr,
const std::string& suffix,
bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
const std::vector<float>& scale_data = {1.0f},
int mask = 0) {
const auto target_key = key_ + suffix + "_target";
const auto key_reorder_p = key_ + suffix + "reorder_p";
const auto user_key = key_ + suffix + "_user";
auto target_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
if (target_memory_p == nullptr) {
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
dev_ctx_.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
dnnl::reorder::primitive_desc reorder_pdesc;
if (is_int8<T>()) {
dnnl::primitive_attr attr;
attr.set_output_scales(mask, scale_data);
reorder_pdesc = dnnl::reorder::primitive_desc(
*user_memory_p, *target_memory_p, attr);
} else {
reorder_pdesc =
dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
}
auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder(
"int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
dev_ctx_.SetBlob(user_key, user_memory_p);
dev_ctx_.SetBlob(target_key, target_memory_p);
} else if (!is_persistent) {
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto user_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr);
// TODO(jczaja): Here we detect if reorder is cached it means it is needed
// need to change this to get rid of keys
auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
platform::RecordEvent record_reorder(
"int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
}
}
return target_memory_p;
}
std::shared_ptr<dnnl::memory> AcquireMemory(const std::string& suffix) {
const auto local_key = key_ + suffix;
return std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
}
const MKLDNNDeviceContext& dev_ctx_;
dnnl::engine engine_;
platform::Place place_;
std::string key_common_;
std::string key_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
template <typename T> template <typename T>
class BinaryMKLDNNHandler using ReductionMKLDNNHandler = phi::funcs::ReductionOneDNNHandler<T>;
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
BinaryMKLDNNHandler(const dnnl::algorithm algo,
const int axis,
const dnnl::engine engine,
platform::Place cpu_place,
const Tensor* x,
const Tensor* y,
Tensor* out,
float scale_x,
float scale_y,
float scale_out,
const dnnl::post_ops& post_ops = dnnl::post_ops{})
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src_x_tz = phi::vectorize(x->dims());
const auto src_y_tz = phi::vectorize(y->dims());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto rankdiff = x->dims().size() - y->dims().size();
const auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: phi::vectorize(out->dims());
auto src0_md = x->mem_desc();
auto src1_md = y->mem_desc();
if (rankdiff > 0) { // Second input is of smaller rank than first
std::vector<int64_t> dims1_ex(rankdiff, 1);
dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)),
src_y_tz.begin(),
src_y_tz.end());
// For broadcasting for NHWC we need rotate extended shape
if (MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC) {
std::rotate(dims1_ex.begin() + 1, dims1_ex.end() - 1, dims1_ex.end());
}
src1_md = src1_md.reshape(dims1_ex);
} else if (rankdiff < 0) { // First input is of smaller than second
std::vector<int64_t> dims0_ex(-rankdiff, 1);
dims0_ex.insert(next(dims0_ex.begin(), (axis == -1 ? -rankdiff : axis)),
src_x_tz.begin(),
src_x_tz.end());
// For broadcasting for NHWC we need rotate extended shape
if (MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC) {
std::rotate(dims0_ex.begin() + 1, dims0_ex.end() - 1, dims0_ex.end());
}
src0_md = src0_md.reshape(dims0_ex);
}
const auto dst_md = memory::desc(
dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
auto attributes =
CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
if (x->numel() < y->numel()) {
this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src1_md, src0_md, dst_md);
} else {
this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src0_md, src1_md, dst_md);
}
}
std::shared_ptr<dnnl::memory> AcquireSecondSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
to_void_cast<T>(input_data));
}
private:
static inline dnnl::primitive_attr CreateAttributes(
dnnl::algorithm op,
float scale_x,
float scale_y,
float scale_out,
dnnl::post_ops post_ops = dnnl::post_ops{}) {
// Scales set in attributes for inputs contibute to the output equation
// in the following way (assuming no broadcasting takes place):
// output_i = scale_0 * x_i <+ or *> scale_1 * y_i;
// Hence we have to create scales that will:
// 1. Dequantize both values, by multiplying with (1.0 / scale_x_or_y)
// 2. Quantize their result to output scale range, by multiplying with
// (scale_z)
// If we combine these two, we end up with following equation
// output = scale_out * (1/scale_x * x <* or +> 1/scale_y * y)
// Hence, to mimic such behaviour using provided interface,
// For add operation the equation is equal to:
// output = (scale_out / scale_x) * x + (scale_out / scale_y) * y
// <scale_0> <scale_1>
// For mul operation on the other hand
// output = (scale_out / scale_x) * x * (1.0 / scale_y) * y
// <scale_0> <scale_1>
float scale_0 = scale_out / scale_x;
float scale_1 =
op == dnnl::algorithm::binary_add ? scale_out / scale_y : 1.0 / scale_y;
dnnl::primitive_attr attributes;
attributes.set_scales(
/* input_x_id = */ DNNL_ARG_SRC_0, /* mask = */ 0, {scale_0});
attributes.set_scales(
/* input_y_id = */ DNNL_ARG_SRC_1, /* mask = */ 0, {scale_1});
if (post_ops.len() > 0) attributes.set_post_ops(post_ops);
return attributes;
}
};
template <typename T> template <typename T>
class BroadcastDataMKLDNNHandler using BroadcastDataMKLDNNHandler = phi::funcs::BroadcastDataOneDNNHandler<T>;
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
const dnnl::engine engine,
platform::Place cpu_place,
const Tensor* x,
Tensor* out,
float scale_x,
float scale_y,
const std::vector<int64_t>& extended_x_dims)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_tz = phi::vectorize(out->dims());
const auto src0_md =
dnnl::memory::desc(src0_tz,
platform::MKLDNNGetDataType<T>(),
platform::GetPlainMKLDNNFormat(src0_tz.size()));
const auto src1_md = x->mem_desc().reshape(extended_x_dims);
dnnl::primitive_attr attributes;
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src0_md, src1_md, src0_md);
}
template <typename T_out = T> template <typename T>
std::shared_ptr<dnnl::memory> AcquireZeroedDstMemory(framework::Tensor* out) { using BinaryMKLDNNHandler = phi::funcs::BinaryOneDNNHandler<T>;
T_out* ptr = out->mutable_data<T_out>(this->place_,
this->fwd_pd_->dst_desc().get_size());
memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
static void AppendActivation(const framework::ExecutionContext& ctx, static void AppendActivation(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops, // NOLINT dnnl::post_ops& post_ops, // NOLINT
...@@ -624,34 +111,6 @@ static void AppendActivation(const framework::ExecutionContext& ctx, ...@@ -624,34 +111,6 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
} }
} }
template <typename T>
class ReductionMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> {
public:
ReductionMKLDNNHandler(const dnnl::algorithm algo,
const float p,
const float eps,
const dnnl::engine engine,
platform::Place cpu_place,
const Tensor* x,
const Tensor* out,
std::vector<int64_t> out_tz,
const dnnl::primitive_attr& attrs = NULL)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
cpu_place) {
const auto out_md = memory::desc(out_tz,
platform::MKLDNNGetDataType<T>(),
dnnl::memory::format_tag::any);
if (attrs)
this->AcquireForwardPrimitiveDescriptor(
attrs, algo, x->mem_desc(), out_md, p, eps);
else
this->AcquireForwardPrimitiveDescriptor(
algo, x->mem_desc(), out_md, p, eps);
}
};
template <typename T> template <typename T>
constexpr bool IsInt8() { constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value; return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
...@@ -1071,37 +530,5 @@ class ReorderMKLDNNHandler { ...@@ -1071,37 +530,5 @@ class ReorderMKLDNNHandler {
dnnl::memory::data_type dtype_, dtype_dst_; dnnl::memory::data_type dtype_, dtype_dst_;
dnnl::engine engine_; dnnl::engine engine_;
}; };
template <typename T>
static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx,
framework::Tensor* output,
std::vector<int64_t> dst_tz,
const dnnl::engine& engine,
std::shared_ptr<dnnl::memory::desc>& dst_md, // NOLINT
std::shared_ptr<dnnl::memory>& dst_memory, // NOLINT
MKLDNNMemoryFormat output_format) {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size();
MKLDNNMemoryFormat dst_fmt;
PADDLE_ENFORCE_LE(dst_dims,
5,
platform::errors::InvalidArgument(
"Dst memory for quantization can not have "
"dims > 5. But received dst_dims is %d.",
dst_dims));
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, output_format);
auto tmp_dst_md =
platform::MKLDNNMemDesc({dst_tz},
paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<T>::DataType()),
dst_fmt);
dst_md.reset(new dnnl::memory::desc(tmp_dst_md));
dst_memory.reset(
new dnnl::memory(*dst_md, engine, to_void_cast<T>(output_data)));
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -56,7 +56,7 @@ BackendSet GetTensorBackendSet(const phi::TensorBase& t) { ...@@ -56,7 +56,7 @@ BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) { if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) {
BackendSet backend_set(phi::TransToPhiBackend(t.place())); BackendSet backend_set(phi::TransToPhiBackend(t.place()));
switch (t.layout()) { switch (t.layout()) {
case DataLayout::MKLDNN: case DataLayout::ONEDNN:
backend_set = backend_set | BackendSet(Backend::ONEDNN); backend_set = backend_set | BackendSet(Backend::ONEDNN);
break; break;
default: default:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <thread>
#include "dnnl.hpp" // NOLINT #include "dnnl.hpp" // NOLINT
#include "glog/logging.h" #include "glog/logging.h"
...@@ -94,6 +95,106 @@ inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) { ...@@ -94,6 +95,106 @@ inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) {
} }
} }
template <typename Type>
dnnl::memory::data_type oneDNNGetDataType() {
return dnnl::memory::data_type::undef;
}
template <>
inline dnnl::memory::data_type oneDNNGetDataType<float>() {
return dnnl::memory::data_type::f32;
}
template <>
inline dnnl::memory::data_type oneDNNGetDataType<int32_t>() {
return dnnl::memory::data_type::s32;
}
template <>
inline dnnl::memory::data_type oneDNNGetDataType<int8_t>() {
return dnnl::memory::data_type::s8;
}
template <>
inline dnnl::memory::data_type oneDNNGetDataType<uint8_t>() {
return dnnl::memory::data_type::u8;
}
template <>
inline dnnl::memory::data_type oneDNNGetDataType<dtype::bfloat16>() {
return dnnl::memory::data_type::bf16;
}
inline std::vector<std::vector<int64_t>> ToOneDNNPadding(
const std::vector<int64_t>& paddings) {
if (paddings.size() == 6) {
int padding_front = paddings[0];
int padding_back = paddings[1];
int padding_top = paddings[2];
int padding_bottom = paddings[3];
int padding_left = paddings[4];
int padding_right = paddings[5];
return {{padding_front, padding_top, padding_left},
{padding_back, padding_bottom, padding_right}};
} else {
int padding_top = paddings[0];
int padding_bottom = paddings[1];
int padding_left = paddings[2];
int padding_right = paddings[3];
return {{padding_top, padding_left}, {padding_bottom, padding_right}};
}
}
template <typename T>
inline void AppendKey(std::string* key, const T& num) {
key->append(std::to_string(num));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::memory::format_tag& format) {
key->append(std::to_string(static_cast<int>(format)));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::memory::data_type& data_type) {
key->append(std::to_string(static_cast<int>(data_type)));
}
template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
key->append(std::to_string(static_cast<int>(algorithm)));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::normalization_flags& flags) {
key->append(std::to_string(static_cast<int>(flags)));
}
inline void AppendKey(std::string* key, const std::string& str) {
key->append(str);
}
inline void AppendKey(std::string* key, const char* str) { key->append(str); }
template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
for (size_t i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
template <typename... ArgTypes>
inline std::string CreateKey(const OneDNNContext& dev_ctx, ArgTypes&&... args) {
std::string key;
key.reserve(64);
using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += OneDNNContext::tls().get_key_suffix();
return key;
}
inline void MatchShapeToLayout(DenseTensor* tensor_in, inline void MatchShapeToLayout(DenseTensor* tensor_in,
DataLayout from, DataLayout from,
DataLayout to) { DataLayout to) {
...@@ -117,28 +218,28 @@ inline void MatchShapeToLayout(DenseTensor* tensor_in, ...@@ -117,28 +218,28 @@ inline void MatchShapeToLayout(DenseTensor* tensor_in,
// at last nhwC, so for dim==2 these layouts are the same and nothing should // at last nhwC, so for dim==2 these layouts are the same and nothing should
// be done. Similarly for dim==1 when you have just one possible combination. // be done. Similarly for dim==1 when you have just one possible combination.
if (tensor_in->dims().size() < 3) { if (tensor_in->dims().size() < 3) {
VLOG(3) << "Keeping MKLDNN/NHWC/NDHWC output_shape" VLOG(3) << "Keeping ONEDNN/NHWC/NDHWC output_shape"
<< print_dims(phi::vectorize<int>(tensor_in->dims())); << print_dims(phi::vectorize<int>(tensor_in->dims()));
return; return;
} }
switch (from) { switch (from) {
case DataLayout::MKLDNN: case DataLayout::ONEDNN:
if ((to == DataLayout::NHWC) || (to == DataLayout::NDHWC)) { if ((to == DataLayout::NHWC) || (to == DataLayout::NDHWC)) {
auto dims = phi::vectorize<int>(tensor_in->dims()); auto dims = phi::vectorize<int>(tensor_in->dims());
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
tensor_in->Resize(phi::make_ddim(dims)); tensor_in->Resize(phi::make_ddim(dims));
VLOG(3) << "Rotating Shape from: MKLDNN to: NHWC/NDHWC output_shape" VLOG(3) << "Rotating Shape from: ONEDNN to: NHWC/NDHWC output_shape"
<< print_dims(dims); << print_dims(dims);
} }
break; break;
case DataLayout::NHWC: case DataLayout::NHWC:
case DataLayout::NDHWC: case DataLayout::NDHWC:
if (to == DataLayout::MKLDNN) { if (to == DataLayout::ONEDNN) {
auto dims = phi::vectorize<int>(tensor_in->dims()); auto dims = phi::vectorize<int>(tensor_in->dims());
std::rotate(dims.begin() + 1, dims.end() - 1, dims.end()); std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
tensor_in->Resize(phi::make_ddim(dims)); tensor_in->Resize(phi::make_ddim(dims));
VLOG(3) << "Rotating Shape from: NHWC/NDHWC to: MKLDNN output_shape" VLOG(3) << "Rotating Shape from: NHWC/NDHWC to: ONEDNN output_shape"
<< print_dims(dims); << print_dims(dims);
} }
break; break;
...@@ -158,5 +259,22 @@ inline dnnl::memory::desc OneDNNMemDesc(const std::vector<int64_t>& dims, ...@@ -158,5 +259,22 @@ inline dnnl::memory::desc OneDNNMemDesc(const std::vector<int64_t>& dims,
return dnnl::memory::desc({dims}, data_type, format); return dnnl::memory::desc({dims}, data_type, format);
} }
inline std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
inline std::string ExtendKeyWithThreadInfoIfNeeded(const OneDNNContext& dev_ctx,
const std::string& key) {
return (OneDNNContext::tls().is_tid_used_in_key() == true)
? key + "-t:" + ThreadIDasStr()
: key;
}
template <typename T>
bool constexpr is_int8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -33,19 +33,402 @@ namespace funcs { ...@@ -33,19 +33,402 @@ namespace funcs {
using user_function = std::function<std::shared_ptr<float>(const float*)>; using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = dnnl::memory; using memory = dnnl::memory;
using Place = phi::Place;
using MKLDNNMemoryFormat = dnnl::memory::format_tag; using OneDNNMemoryFormat = dnnl::memory::format_tag;
template <typename T, template <typename T,
typename TForward, typename TForward,
typename TBackward = onednn_dummy_primitive, typename TBackward = onednn_dummy_primitive,
typename TBackward_params = onednn_dummy_primitive> typename TBackward_params = onednn_dummy_primitive>
class MKLDNNHandlerNoCachingT { class OneDNNHandlerT {
public: public:
MKLDNNHandlerNoCachingT(dnnl::engine engine, Place cpu_place) OneDNNHandlerT(const OneDNNContext& dev_ctx,
dnnl::engine engine,
Place cpu_place,
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
place_(cpu_place),
key_common_(base_key),
key_(ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
OneDNNContext::tls().log_lib_version();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@fwd_p";
auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) {
forward_p = std::make_shared<TForward>(*fwd_pd_);
dev_ctx_.SetBlob(key_p, forward_p);
}
return forward_p;
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
const std::string key_p = key_ + "@bwd_p";
auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
backward_p = std::make_shared<TBackward>(*bwd_pd_);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
const std::string key_p = key_ + "@bwd_w_p";
auto backward_p =
std::static_pointer_cast<TBackward_params>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
errors::Unavailable("BWD_PD should be set when "
"getting BWD prim witk key: %s .",
key_p));
backward_p = std::make_shared<TBackward_params>(*bwd_w_pd_);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src_mem_p");
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output) {
T_out* ptr =
output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(
fwd_pd_->dst_desc(), ptr, "@dst_mem_p");
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), "@dstt_mem_p");
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(const DenseTensor* output) {
const T_out* output_data = output->data<T_out>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
to_void_cast<T_out>(output_data),
"@bwd-dst_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
const DenseTensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(
bwd_pd_->diff_dst_desc(), to_void_cast<T>(ptr), "@diff_dst_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(DenseTensor* diffsrc) {
T* ptr =
diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(
bwd_pd_->diff_src_desc(), ptr, "@diff_src_mem_p");
}
// Buffer of given DenseTensor is used for oneDNN computation
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
DenseTensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(
bwd_w_pd_->diff_weights_desc(), ptr, "@diff_wei_mem_p");
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
"@diff_wei_mem_p");
}
protected:
bool isCached() {
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
return (fwd_pd_ != nullptr);
}
bool isBwdCached() {
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
return false;
} else {
if (std::is_same<TBackward_params, onednn_dummy_primitive>::value ==
false) {
const std::string key_bw_w_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_bw_w_pd));
}
// When BWD is cached then still we need to Get FWD PD
const std::string key_fpd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fpd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
errors::Unavailable(
"Error: FWD PD should be set when BWD PD is cached."));
return true;
}
}
// If your primitive descriptor requires attributes, pass them as a
// first argument and paramters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
// This is used when we can recreate FWD PD in BWD so
// we do not need to pass FWD to BWD
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_pd, fwd_pd_);
}
}
// Using sfinae to specialise variadic function. Workaround for not having
// if constexpr in C++ 11.
template <class First, class... Args>
typename std::enable_if<std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
fwd_desc, first, engine_);
}
template <class First, class... Args>
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<First>(first),
std::forward<Args>(args)...);
fwd_pd_ =
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
errors::Unavailable("Get OneDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_pd_);
}
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
errors::Unavailable("Get OneDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_w_pd_ == nullptr) {
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_w_pd_);
}
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
const std::string& suffix) {
return std::static_pointer_cast<dnnl::memory>(
dev_ctx_.GetBlob(key_ + suffix));
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md, void* ptr, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<dnnl::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
dnnl::memory::desc md, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<dnnl::memory>(md, engine_);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
const std::shared_ptr<dnnl::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = OneDNNContext::tls().get_stream();
paddle::platform::RecordEvent record_reorder(
"int_reorder",
paddle::platform::TracerEventType::UserDefined,
2,
paddle::platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
}
template <typename F = T>
std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
const dnnl::memory::desc& user_md,
const dnnl::memory::desc& target_md,
void* ptr,
const std::string& suffix,
bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
const std::vector<float>& scale_data = {1.0f},
int mask = 0) {
const auto target_key = key_ + suffix + "_target";
const auto key_reorder_p = key_ + suffix + "reorder_p";
const auto user_key = key_ + suffix + "_user";
auto target_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
if (target_memory_p == nullptr) {
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
dev_ctx_.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
dnnl::reorder::primitive_desc reorder_pdesc;
if (is_int8<T>()) {
dnnl::primitive_attr attr;
attr.set_output_scales(mask, scale_data);
reorder_pdesc = dnnl::reorder::primitive_desc(
*user_memory_p, *target_memory_p, attr);
} else {
reorder_pdesc =
dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
}
auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
auto& astream = OneDNNContext::tls().get_stream();
paddle::platform::RecordEvent record_reorder(
"int_reorder",
paddle::platform::TracerEventType::UserDefined,
2,
paddle::platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
dev_ctx_.SetBlob(user_key, user_memory_p);
dev_ctx_.SetBlob(target_key, target_memory_p);
} else if (!is_persistent) {
auto& astream = OneDNNContext::tls().get_stream();
auto user_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr);
// TODO(jczaja): Here we detect if reorder is cached it means it is needed
// need to change this to get rid of keys
auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
paddle::platform::RecordEvent record_reorder(
"int_reorder",
paddle::platform::TracerEventType::UserDefined,
2,
paddle::platform::EventRole::kUniqueOp);
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
}
}
return target_memory_p;
}
std::shared_ptr<dnnl::memory> AcquireMemory(const std::string& suffix) {
const auto local_key = key_ + suffix;
return std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
}
const OneDNNContext& dev_ctx_;
dnnl::engine engine_;
Place place_;
std::string key_common_;
std::string key_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
template <typename T,
typename TForward,
typename TBackward = onednn_dummy_primitive,
typename TBackward_params = onednn_dummy_primitive>
class OneDNNHandlerNoCachingT {
public:
OneDNNHandlerNoCachingT(dnnl::engine engine, Place cpu_place)
: engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) { : engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) {
phi::OneDNNContext::tls().log_lib_version(); OneDNNContext::tls().log_lib_version();
} }
std::shared_ptr<TForward> AcquireForwardPrimitive() { std::shared_ptr<TForward> AcquireForwardPrimitive() {
...@@ -57,10 +440,9 @@ class MKLDNNHandlerNoCachingT { ...@@ -57,10 +440,9 @@ class MKLDNNHandlerNoCachingT {
} }
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() { std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_,
bwd_w_pd_, errors::Unavailable("BWD_PD should be set when "
phi::errors::Unavailable("BWD_PD should be set when " "getting BWD prim ."));
"getting BWD prim ."));
return std::make_shared<TBackward_params>(*bwd_w_pd_); return std::make_shared<TBackward_params>(*bwd_w_pd_);
} }
...@@ -102,12 +484,12 @@ class MKLDNNHandlerNoCachingT { ...@@ -102,12 +484,12 @@ class MKLDNNHandlerNoCachingT {
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr); return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr);
} }
// Buffer of given Tensor is used for oneDNN computation // Buffer of given DenseTensor is used for oneDNN computation
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory( std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
DenseTensor* diff_weights) { DenseTensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_, bwd_w_pd_,
phi::errors::Unavailable( errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights.")); "BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>( T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size()); place_, bwd_w_pd_->diff_weights_desc().get_size());
...@@ -119,7 +501,7 @@ class MKLDNNHandlerNoCachingT { ...@@ -119,7 +501,7 @@ class MKLDNNHandlerNoCachingT {
std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) { std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_, bwd_w_pd_,
phi::errors::Unavailable( errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights.")); "BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc()); return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc());
} }
...@@ -161,7 +543,7 @@ class MKLDNNHandlerNoCachingT { ...@@ -161,7 +543,7 @@ class MKLDNNHandlerNoCachingT {
// AcquireForwardPrimitiveDescriptor // AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, fwd_pd_,
phi::errors::Unavailable("Get MKLDNN Forward primitive %s failed.")); errors::Unavailable("Get oneDNN Forward primitive %s failed."));
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...); auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>( bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_); bwd_desc, engine_, *fwd_pd_);
...@@ -173,7 +555,7 @@ class MKLDNNHandlerNoCachingT { ...@@ -173,7 +555,7 @@ class MKLDNNHandlerNoCachingT {
// AcquireForwardPrimitiveDescriptor // AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, fwd_pd_,
phi::errors::Unavailable("Get MKLDNN Forward primitive %s failed.")); errors::Unavailable("Get oneDNN Forward primitive %s failed."));
auto bwd_desc = auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...); typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>( bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
...@@ -195,7 +577,7 @@ class MKLDNNHandlerNoCachingT { ...@@ -195,7 +577,7 @@ class MKLDNNHandlerNoCachingT {
auto reorder_p = auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p); std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = phi::OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
paddle::platform::RecordEvent record_reorder( paddle::platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
...@@ -227,7 +609,7 @@ class MKLDNNHandlerNoCachingT { ...@@ -227,7 +609,7 @@ class MKLDNNHandlerNoCachingT {
auto reorder_p = auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p); std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = phi::OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
paddle::platform::RecordEvent record_reorder( paddle::platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
paddle::platform::TracerEventType::UserDefined, paddle::platform::TracerEventType::UserDefined,
...@@ -252,7 +634,7 @@ class MKLDNNHandlerNoCachingT { ...@@ -252,7 +634,7 @@ class MKLDNNHandlerNoCachingT {
template <typename T> template <typename T>
class ActivationOneDNNHandler class ActivationOneDNNHandler
: public MKLDNNHandlerNoCachingT<T, : public OneDNNHandlerNoCachingT<T,
dnnl::eltwise_forward, dnnl::eltwise_forward,
dnnl::eltwise_backward> { dnnl::eltwise_backward> {
public: public:
...@@ -262,7 +644,7 @@ class ActivationOneDNNHandler ...@@ -262,7 +644,7 @@ class ActivationOneDNNHandler
const dnnl::engine engine, const dnnl::engine engine,
Place cpu_place, Place cpu_place,
const DenseTensor* x) const DenseTensor* x)
: MKLDNNHandlerNoCachingT<T, : OneDNNHandlerNoCachingT<T,
dnnl::eltwise_forward, dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine, cpu_place) { dnnl::eltwise_backward>(engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
...@@ -279,7 +661,7 @@ class ActivationOneDNNHandler ...@@ -279,7 +661,7 @@ class ActivationOneDNNHandler
Place cpu_place, Place cpu_place,
const DenseTensor* x, const DenseTensor* x,
const DenseTensor* dout) const DenseTensor* dout)
: MKLDNNHandlerNoCachingT<T, : OneDNNHandlerNoCachingT<T,
dnnl::eltwise_forward, dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine, cpu_place) { dnnl::eltwise_backward>(engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
...@@ -330,7 +712,7 @@ class ReorderOneDNNHandler { ...@@ -330,7 +712,7 @@ class ReorderOneDNNHandler {
return std::make_shared<dnnl::memory>(md, engine_, ptr); return std::make_shared<dnnl::memory>(md, engine_, ptr);
} }
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt, std::shared_ptr<dnnl::memory> AcquireSrcMemory(const OneDNNMemoryFormat& fmt,
void* ptr) { void* ptr) {
auto md = dnnl::memory::desc(dims_, dtype_, fmt); auto md = dnnl::memory::desc(dims_, dtype_, fmt);
return std::make_shared<dnnl::memory>(md, engine_, ptr); return std::make_shared<dnnl::memory>(md, engine_, ptr);
...@@ -347,7 +729,7 @@ class ReorderOneDNNHandler { ...@@ -347,7 +729,7 @@ class ReorderOneDNNHandler {
} }
std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output, std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output,
const MKLDNNMemoryFormat& fmt, const OneDNNMemoryFormat& fmt,
Place place) { Place place) {
auto dst_md = OneDNNMemDesc(dims_, dtype_dst_, fmt); auto dst_md = OneDNNMemDesc(dims_, dtype_dst_, fmt);
auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size()); auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size());
...@@ -372,7 +754,7 @@ class ReorderOneDNNHandler { ...@@ -372,7 +754,7 @@ class ReorderOneDNNHandler {
std::shared_ptr<dnnl::memory> AcquireDstMemory( std::shared_ptr<dnnl::memory> AcquireDstMemory(
DenseTensor* output, DenseTensor* output,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat& fmt, const OneDNNMemoryFormat& fmt,
Place place) { Place place) {
auto dst_md = OneDNNMemDesc(dims, dtype_dst_, fmt); auto dst_md = OneDNNMemDesc(dims, dtype_dst_, fmt);
auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size()); auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size());
...@@ -400,5 +782,170 @@ class ReorderOneDNNHandler { ...@@ -400,5 +782,170 @@ class ReorderOneDNNHandler {
dnnl::engine engine_; dnnl::engine engine_;
}; };
template <typename T>
class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
public:
BinaryOneDNNHandler(const dnnl::algorithm algo,
const int axis,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x,
const DenseTensor* y,
DenseTensor* out,
float scale_x,
float scale_y,
float scale_out,
const dnnl::post_ops& post_ops = dnnl::post_ops{})
: OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src_x_tz = vectorize(x->dims());
const auto src_y_tz = vectorize(y->dims());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto rankdiff = x->dims().size() - y->dims().size();
const auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: vectorize(out->dims());
auto src0_md = x->mem_desc();
auto src1_md = y->mem_desc();
if (rankdiff > 0) { // Second input is of smaller rank than first
std::vector<int64_t> dims1_ex(rankdiff, 1);
dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)),
src_y_tz.begin(),
src_y_tz.end());
// For broadcasting for NHWC we need rotate extended shape
if (OneDNNContext::tls().get_cur_paddle_data_layout() ==
DataLayout::kNHWC) {
std::rotate(dims1_ex.begin() + 1, dims1_ex.end() - 1, dims1_ex.end());
}
src1_md = src1_md.reshape(dims1_ex);
} else if (rankdiff < 0) { // First input is of smaller than second
std::vector<int64_t> dims0_ex(-rankdiff, 1);
dims0_ex.insert(next(dims0_ex.begin(), (axis == -1 ? -rankdiff : axis)),
src_x_tz.begin(),
src_x_tz.end());
// For broadcasting for NHWC we need rotate extended shape
if (OneDNNContext::tls().get_cur_paddle_data_layout() ==
DataLayout::kNHWC) {
std::rotate(dims0_ex.begin() + 1, dims0_ex.end() - 1, dims0_ex.end());
}
src0_md = src0_md.reshape(dims0_ex);
}
const auto dst_md =
memory::desc(dst_tz, oneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
auto attributes =
CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
if (x->numel() < y->numel()) {
this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src1_md, src0_md, dst_md);
} else {
this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src0_md, src1_md, dst_md);
}
}
std::shared_ptr<dnnl::memory> AcquireSecondSrcMemory(
const DenseTensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
to_void_cast<T>(input_data));
}
private:
static inline dnnl::primitive_attr CreateAttributes(
dnnl::algorithm op,
float scale_x,
float scale_y,
float scale_out,
dnnl::post_ops post_ops = dnnl::post_ops{}) {
// Scales set in attributes for inputs contibute to the output equation
// in the following way (assuming no broadcasting takes place):
// output_i = scale_0 * x_i <+ or *> scale_1 * y_i;
// Hence we have to create scales that will:
// 1. Dequantize both values, by multiplying with (1.0 / scale_x_or_y)
// 2. Quantize their result to output scale range, by multiplying with
// (scale_z)
// If we combine these two, we end up with following equation
// output = scale_out * (1/scale_x * x <* or +> 1/scale_y * y)
// Hence, to mimic such behaviour using provided interface,
// For add operation the equation is equal to:
// output = (scale_out / scale_x) * x + (scale_out / scale_y) * y
// <scale_0> <scale_1>
// For mul operation on the other hand
// output = (scale_out / scale_x) * x * (1.0 / scale_y) * y
// <scale_0> <scale_1>
float scale_0 = scale_out / scale_x;
float scale_1 =
op == dnnl::algorithm::binary_add ? scale_out / scale_y : 1.0 / scale_y;
dnnl::primitive_attr attributes;
attributes.set_scales(
/* input_x_id = */ DNNL_ARG_SRC_0, /* mask = */ 0, {scale_0});
attributes.set_scales(
/* input_y_id = */ DNNL_ARG_SRC_1, /* mask = */ 0, {scale_1});
if (post_ops.len() > 0) attributes.set_post_ops(post_ops);
return attributes;
}
};
template <typename T>
class BroadcastDataOneDNNHandler
: public OneDNNHandlerNoCachingT<T, dnnl::binary> {
public:
BroadcastDataOneDNNHandler(const dnnl::algorithm algo,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x,
DenseTensor* out,
float scale_x,
float scale_y,
const std::vector<int64_t>& extended_x_dims)
: OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_tz = vectorize(out->dims());
const auto src0_md = dnnl::memory::desc(
src0_tz, oneDNNGetDataType<T>(), GetPlainOneDNNFormat(src0_tz.size()));
const auto src1_md = x->mem_desc().reshape(extended_x_dims);
dnnl::primitive_attr attributes;
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
this->AcquireForwardPrimitiveDescriptor(
attributes, algo, src0_md, src1_md, src0_md);
}
template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireZeroedDstMemory(DenseTensor* out) {
T_out* ptr = out->mutable_data<T_out>(this->place_,
this->fwd_pd_->dst_desc().get_size());
memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
template <typename T>
class ReductionOneDNNHandler
: public OneDNNHandlerNoCachingT<T, dnnl::reduction> {
public:
ReductionOneDNNHandler(const dnnl::algorithm algo,
const float p,
const float eps,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x,
const DenseTensor* out,
std::vector<int64_t> out_tz,
const dnnl::primitive_attr& attrs = NULL)
: OneDNNHandlerNoCachingT<T, dnnl::reduction>(engine, cpu_place) {
const auto out_md = memory::desc(
out_tz, oneDNNGetDataType<T>(), dnnl::memory::format_tag::any);
if (attrs)
this->AcquireForwardPrimitiveDescriptor(
attrs, algo, x->mem_desc(), out_md, p, eps);
else
this->AcquireForwardPrimitiveDescriptor(
algo, x->mem_desc(), out_md, p, eps);
}
};
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -32,7 +32,7 @@ namespace experimental { ...@@ -32,7 +32,7 @@ namespace experimental {
* more specific, we need to distinguish the calculation method. * more specific, we need to distinguish the calculation method.
* *
* Such as the kernel for CPU device, it can be a native CPU kernel, * Such as the kernel for CPU device, it can be a native CPU kernel,
* or a kernel implemented by MKLDNN library. * or a kernel implemented by oneDNN library.
* *
* Note(chenweihang): HIP is not needed now, we can added it if needed * Note(chenweihang): HIP is not needed now, we can added it if needed
* in the future * in the future
......
...@@ -40,7 +40,7 @@ enum class DataLayout { ...@@ -40,7 +40,7 @@ enum class DataLayout {
NCHW, NCHW,
NCDHW, NCDHW,
NDHWC, NDHWC,
MKLDNN, ONEDNN,
SPARSE_COO, SPARSE_COO,
SPARSE_CSR, SPARSE_CSR,
PSTRING_UNION, PSTRING_UNION,
...@@ -62,7 +62,7 @@ enum class DataLayout { ...@@ -62,7 +62,7 @@ enum class DataLayout {
kAnyLayout = ANY, kAnyLayout = ANY,
kNHWC = NHWC, kNHWC = NHWC,
kNCHW = NCHW, kNCHW = NCHW,
kMKLDNN = MKLDNN, // all layouts supported by MKLDNN internally kMKLDNN = ONEDNN, // all layouts supported by ONEDNN internally
kNDHWC = NDHWC, kNDHWC = NDHWC,
kNCDHW = NCDHW, kNCDHW = NCDHW,
}; };
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "paddle/phi/kernels/log_softmax_kernel.h" #include "paddle/phi/kernels/log_softmax_kernel.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -23,16 +23,15 @@ ...@@ -23,16 +23,15 @@
namespace phi { namespace phi {
template <typename T> template <typename T>
class LogSoftmaxMKLDNNHandler class LogSoftmaxOneDNNHandler
: public paddle::platform:: : public funcs::OneDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward> {
MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward> {
public: public:
LogSoftmaxMKLDNNHandler(const dnnl::engine mkldnn_engine, LogSoftmaxOneDNNHandler(const dnnl::engine onednn_engine,
Place cpu_place, Place cpu_place,
const DenseTensor& x, const DenseTensor& x,
const int axis) const int axis)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>( : funcs::OneDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
mkldnn_engine, cpu_place) { onednn_engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, x.mem_desc(), axis); dnnl::prop_kind::forward_inference, x.mem_desc(), axis);
} }
...@@ -43,11 +42,11 @@ void LogSoftmaxKernel(const Context& dev_ctx, ...@@ -43,11 +42,11 @@ void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int axis, int axis,
DenseTensor* out) { DenseTensor* out) {
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
axis = axis >= 0 ? axis : x.dims().size() + axis; axis = axis >= 0 ? axis : x.dims().size() + axis;
LogSoftmaxMKLDNNHandler<T> handler( LogSoftmaxOneDNNHandler<T> handler(
mkldnn_engine, dev_ctx.GetPlace(), x, axis); onednn_engine, dev_ctx.GetPlace(), x, axis);
auto src_memory_p = handler.AcquireSrcMemory(&x); auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(out); auto dst_memory_p = handler.AcquireDstMemory(out);
......
...@@ -97,7 +97,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx, ...@@ -97,7 +97,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
// NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in
// data_transfer.cc // data_transfer.cc
if (!x.IsInitialized() && src_layout == DataLayout::MKLDNN && if (!x.IsInitialized() && src_layout == DataLayout::ONEDNN &&
dst_layout == DataLayout::NHWC) { dst_layout == DataLayout::NHWC) {
VLOG(4) << src_layout << "->" << dst_layout << " " << x.layout(); VLOG(4) << src_layout << "->" << dst_layout << " " << x.layout();
out->Resize(x.dims()); out->Resize(x.dims());
...@@ -106,7 +106,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx, ...@@ -106,7 +106,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
return; return;
} }
if (src_layout != DataLayout::MKLDNN && dst_layout == DataLayout::MKLDNN) { if (src_layout != DataLayout::ONEDNN && dst_layout == DataLayout::ONEDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur // Just set layout/format. No real transform occur
auto out_format = funcs::OneDNNFormatForSize( auto out_format = funcs::OneDNNFormatForSize(
...@@ -121,16 +121,16 @@ void TransferLayoutMKLDNN(const Context& dev_ctx, ...@@ -121,16 +121,16 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
OneDNNContext::tls().set_cur_paddle_data_layout(src_layout); OneDNNContext::tls().set_cur_paddle_data_layout(src_layout);
} }
out->set_layout(DataLayout::MKLDNN); out->set_layout(DataLayout::ONEDNN);
out->set_format(out_format); out->set_format(out_format);
} else if (src_layout == DataLayout::MKLDNN && } else if (src_layout == DataLayout::ONEDNN &&
dst_layout != DataLayout::MKLDNN) { dst_layout != DataLayout::ONEDNN) {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
// Do transform via MKLDNN lib // Do transform via MKLDNN lib
funcs::innerTransDataLayoutFromOneDNN( funcs::innerTransDataLayoutFromOneDNN(
src_layout, dst_layout, x, out, dev_ctx.GetPlace()); src_layout, dst_layout, x, out, dev_ctx.GetPlace());
} else if (src_layout == DataLayout::MKLDNN && } else if (src_layout == DataLayout::ONEDNN &&
dst_layout == DataLayout::MKLDNN) { dst_layout == DataLayout::ONEDNN) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
src_layout, src_layout,
dst_layout, dst_layout,
......
...@@ -37,7 +37,7 @@ TEST(DataLayout, OStream) { ...@@ -37,7 +37,7 @@ TEST(DataLayout, OStream) {
oss << phi::DataLayout::NCHW; oss << phi::DataLayout::NCHW;
EXPECT_EQ(oss.str(), "NCHW"); EXPECT_EQ(oss.str(), "NCHW");
oss.str(""); oss.str("");
oss << phi::DataLayout::MKLDNN; oss << phi::DataLayout::ONEDNN;
EXPECT_EQ(oss.str(), "MKLDNN"); EXPECT_EQ(oss.str(), "MKLDNN");
oss.str(""); oss.str("");
try { try {
......
...@@ -40,7 +40,7 @@ TEST(DEV_API, transfer_layout) { ...@@ -40,7 +40,7 @@ TEST(DEV_API, transfer_layout) {
DenseTensor x; DenseTensor x;
MetaTensor meta_x(&x); MetaTensor meta_x(&x);
meta_x.set_dtype(DataType::FLOAT32); meta_x.set_dtype(DataType::FLOAT32);
meta_x.set_layout(DataLayout::MKLDNN); meta_x.set_layout(DataLayout::ONEDNN);
meta_x.set_dims(make_ddim({n, c, h, w})); meta_x.set_dims(make_ddim({n, c, h, w}));
DenseTensor out; DenseTensor out;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册