提交 d7652d5f 编写于 作者: J Jacek Czaja

first commit

上级 19746835
......@@ -28,7 +28,6 @@ using dnnl::primitive;
using dnnl::stream;
using framework::DataLayout;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
......
......@@ -30,7 +30,6 @@ using platform::to_void_cast;
using Tensor = phi::DenseTensor;
using dnnl::stream;
using framework::DataLayout;
using platform::GetMKLDNNFormat;
template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> {
......
......@@ -30,7 +30,6 @@ using framework::DataLayout;
using framework::DDim;
using framework::ExecutionContext;
using framework::LoDTensor;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
......
......@@ -25,7 +25,6 @@ using dnnl::reorder;
using dnnl::resampling_forward;
using dnnl::stream;
using framework::DataLayout;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;
template <typename T = float>
......
......@@ -17,7 +17,6 @@ namespace {
using dnnl::memory;
using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MatMulV2MKLDNNHandler;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNFormatForSize;
......
......@@ -20,7 +20,6 @@ namespace operators {
using dnnl::memory;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
......
......@@ -29,7 +29,6 @@ using platform::to_void_cast;
using Tensor = phi::DenseTensor;
using dnnl::stream;
using framework::DataLayout;
using platform::GetMKLDNNFormat;
template <typename T>
class QuantOpKernel : public framework::OpKernel<T> {
......
......@@ -31,7 +31,6 @@ namespace paddle {
namespace operators {
using paddle::framework::LoDTensor;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;
static std::vector<int> extract_shape(
......@@ -89,8 +88,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
// reorder is done into a plain tag to allow usage with blocked formats
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, getPlainFormatTag(x), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
......@@ -98,9 +97,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
astream.wait();
out->Resize(out_dims);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(GetMKLDNNFormat(
reorder_dst_memory_p->get_desc().reshape(phi::vectorize(out_dims))));
out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(phi::vectorize(out_dims)));
}
void InferInOutShape(const framework::ExecutionContext& ctx,
......@@ -362,17 +359,15 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
dout->format(), platform::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx, this->getPlainFormatTag(dout), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->Resize(dx_dims);
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(GetMKLDNNFormat(
reorder_dst_memory_p->get_desc().reshape(phi::vectorize(dx_dims))));
dx->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(phi::vectorize(dx_dims)));
}
void InferOutputShapeInGrad(const framework::ExecutionContext& ctx,
......
......@@ -202,164 +202,6 @@ inline void Reorder(dnnl::memory src,
astream.wait();
}
// Heuristically maps a oneDNN memory descriptor onto the closest matching
// dnnl::memory::format_tag by inspecting the tensor rank (ndims), the
// blocking structure (inner_nblks / inner_blks / inner_idxs) and the
// relative ordering of the strides.
//
// @param mem_desc  oneDNN blocked memory descriptor to classify
//                  (passed by const reference - the descriptor struct is
//                  large and this function only reads it).
// @return the recognized format tag, or dnnl::memory::format_tag::undef
//         when the layout does not match any known tag.
inline dnnl::memory::format_tag GetMKLDNNFormat(
    const dnnl::memory::desc& mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      // Plain 2-D layout: row-major (nc) vs column-major (cn) is decided by
      // which dimension has the larger stride.
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      // Plain (non-blocked) 4-D layouts, distinguished by stride ordering.
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::abcd;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[1]) {
        return dnnl::memory::format_tag::acdb;
      } else if (strides[0] >= strides[1] && strides[1] >= strides[3] &&
                 strides[3] >= strides[2]) {
        return dnnl::memory::format_tag::abdc;
      } else {
        // NOTE(review): a second cdba branch with a condition identical to
        // the one above used to sit here; it was unreachable and was removed.
        return dnnl::memory::format_tag::dcab;
      }
    } else if (inner_nblks == 1) {
      // Single-level blocked layouts (e.g. nChw16c = channels blocked by 16).
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      // Double-blocked weight layouts used by convolution kernels.
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde4b;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}
// Convenience overload: classifies the layout of a dnnl::memory object by
// delegating to the descriptor-based GetMKLDNNFormat.
// The memory handle is taken by const reference instead of by value to
// avoid copying the underlying shared handle on every call.
inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory& memory) {
  return GetMKLDNNFormat(memory.get_desc());
}
inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
switch (tensor_rank) {
case 1:
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// NOTE:
// The GetMKLDNNFormat function lives here temporarily. It is needed because
// without it a forward declaration caused an error when building with
// "-DWITH_TESTING=ON". This file will be deleted once the md-related
// refactoring is complete.
namespace paddle {
namespace platform {
// Heuristically maps a oneDNN memory descriptor onto the closest matching
// dnnl::memory::format_tag by inspecting the tensor rank (ndims), the
// blocking structure (inner_nblks / inner_blks / inner_idxs) and the
// relative ordering of the strides.
//
// @param mem_desc  oneDNN blocked memory descriptor to classify
//                  (passed by const reference - the descriptor struct is
//                  large and this function only reads it).
// @return the recognized format tag, or dnnl::memory::format_tag::undef
//         when the layout does not match any known tag.
inline dnnl::memory::format_tag GetMKLDNNFormat(
    const dnnl::memory::desc& mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      // Plain 2-D layout: row-major (nc) vs column-major (cn) is decided by
      // which dimension has the larger stride.
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      // Plain (non-blocked) 4-D layouts, distinguished by stride ordering.
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::abcd;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[1]) {
        return dnnl::memory::format_tag::acdb;
      } else if (strides[0] >= strides[1] && strides[1] >= strides[3] &&
                 strides[3] >= strides[2]) {
        return dnnl::memory::format_tag::abdc;
      } else {
        // NOTE(review): a second cdba branch with a condition identical to
        // the one above used to sit here; it was unreachable and was removed.
        return dnnl::memory::format_tag::dcab;
      }
    } else if (inner_nblks == 1) {
      // Single-level blocked layouts (e.g. nChw16c = channels blocked by 16).
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      // Double-blocked weight layouts used by convolution kernels.
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      // NOTE(review): the aBcde4b case below was present in the original
      // in-tree copy of this function but missing from this duplicate;
      // restored so both copies classify identically.
      if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde4b;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}
} // namespace platform
} // namespace paddle
......@@ -123,12 +123,6 @@ inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
meta_.layout = DataLayout::kMKLDNN;
}
dnnl::memory::format_tag format() const;
inline void set_format(const dnnl::memory::format_tag format) {
format_ = format;
}
#endif
/* ------------------------------ */
......
......@@ -351,15 +351,7 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
#ifdef PADDLE_WITH_MKLDNN
dnnl::memory::desc DenseTensor::mem_desc() const {
return mem_desc_ ? mem_desc_
: dnnl::memory::desc(phi::vectorize(meta_.dims),
phi::TransToOneDNNDataType(meta_.dtype),
format_);
}
dnnl::memory::format_tag DenseTensor::format() const {
return mem_desc_ ? paddle::platform::GetMKLDNNFormat(mem_desc_) : format_;
}
return mem_desc_;
#endif
// NOTE: For historical reasons, this interface has a special behavior,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册