Unverified commit 14f261ad authored by Jacek Czaja, committed by GitHub

Final changes to introduce mem_desc to be held in Tensor (#46768)

* first commit

- more fixes

- compilation fix

- compilation fix

- fix

- another fix

- yet another fix

- Fix

- fix to fused ops

- compilation fix

- compilation fix

- another compilation fix

- another fix

- fix

- fix

- fix

- fix

- yet another fix

- fix

- fix

- cosmetic fix

- lint

- Revert some changes (to be brought back later)

- fix to build

- Added prototype of slice

- fix

compilation fix

- compilation fix

- fix

- fix

- Fix

- fix

- fix (modified: cmake/flags.cmake)

* lint

* rerun of CI

* - Fix

* - lint

* - lint2
Parent 339aefac
@@ -1245,11 +1245,6 @@ std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) {
   os << "  - shape: [" << t.dims() << "]\n";
   os << "  - layout: " << phi::DataLayoutToString(t.layout()) << "\n";
-#ifdef PADDLE_WITH_MKLDNN
-  os << "  - format: "
-     << dnnl_fmt_tag2str(static_cast<dnnl_format_tag_t>(t.format())) << "\n";
-#endif
-
   DenseTensor tensor;
   tensor.Resize(t.dims());
   if (paddle::platform::is_cpu_place(t.place())) {
......
@@ -77,9 +77,12 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
     }
   }

-  bool is_NTC() {
-    return (platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc()) ==
-            dnnl::memory::format_tag::ntc);
+  bool is_NTC() { return this->is_NTC(this->fwd_pd_->dst_desc()); }
+
+  bool is_NTC(const dnnl::memory::desc& md) {
+    auto ntc_md = dnnl::memory::desc(
+        md.dims(), md.data_type(), dnnl::memory::format_tag::ntc);
+    return md == ntc_md;
   }

   void reorderRNNdata(void* input_data,
@@ -165,8 +168,7 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
     auto* x_onednn_data = memory_p->get_data_handle();
     memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);

-    if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) ==
-        dnnl::memory::format_tag::ntc) {
+    if (is_NTC(this->fwd_pd_->src_desc())) {
       reorderRNNdata(x_data,
                      x_onednn_data,
                      input_lod,
......
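The refactor above replaces the tag-based check (map a memory::desc back to a format_tag via GetMKLDNNFormat, then compare tags) with direct descriptor comparison: build a reference descriptor with the same dims and data type but the known ntc tag, and test for equality. A minimal standalone sketch of that idiom, assuming oneDNN 2.x headers (plain dnnl code, not Paddle; the dims are invented):

#include <dnnl.hpp>

// Equal descriptors imply identical physical layout, so comparing against a
// purpose-built reference desc answers "is this buffer laid out as ntc?".
bool is_ntc(const dnnl::memory::desc& md) {
  auto ntc_md = dnnl::memory::desc(
      md.dims(), md.data_type(), dnnl::memory::format_tag::ntc);
  return md == ntc_md;
}

int main() {
  dnnl::memory::desc md({4, 10, 32},  // logical {T, N, C} dims, arbitrary
                        dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::tnc);
  return is_ntc(md) ? 1 : 0;  // tnc and ntc differ in strides, so returns 0
}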
@@ -256,8 +256,7 @@ class MultiGRUHandler {
     auto* x_onednn_data = memory_p->get_data_handle();
     memset(x_onednn_data, 0, sizeof(T) * N_ * Ti_ * ICs[0]);

-    if (platform::GetMKLDNNFormat(gru_pds_[{0, L2R}]->src_desc()) ==
-        dnnl::memory::format_tag::ntc) {
+    if (isNTC(gru_pds_[{0, L2R}]->src_desc())) {
       reorderPPtoNTC(x_data, x_onednn_data, x_lod_, 0, L2R);
     } else {
       reorderPPtoTNC(x_data, x_onednn_data, x_lod_, 0, L2R);
@@ -601,16 +600,18 @@ class MultiGRUHandler {
   void reorderOutput(std::shared_ptr<dnnl::memory> mem, int layer) {
     auto* data = mem->get_data_handle();
     auto* hidden_data = to_void_cast(hidden_->mutable_data<Tout>(place_));
-    if (isNTC(layers_ - 1)) {
+
+    if (isNTC(gru_pds_[{layers_ - 1, L2R}]->dst_desc())) {
       reorderNTCtoPP(data, hidden_data, layers_ - 1);
     } else {
       reorderTNCtoPP(data, hidden_data, layers_ - 1);
     }
   }

-  bool isNTC(int layer) {
-    return (platform::GetMKLDNNFormat(gru_pds_[{layer, L2R}]->dst_desc()) ==
-            dnnl::memory::format_tag::ntc);
+  bool isNTC(const dnnl::memory::desc& md) {
+    auto ntc_md = dnnl::memory::desc(
+        md.dims(), md.data_type(), dnnl::memory::format_tag::ntc);
+    return md == ntc_md;
   }

   int getLayers() const { return layers_; }
......
@@ -30,7 +30,6 @@ using platform::to_void_cast;
 using Tensor = phi::DenseTensor;
 using dnnl::stream;
 using phi::DataLayout;
-using platform::GetMKLDNNFormat;

 template <typename T>
 class DeQuantOpKernel : public framework::OpKernel<T> {
......
@@ -29,7 +29,6 @@ using dnnl::stream;
 using framework::DDim;
 using framework::ExecutionContext;
 using LoDTensor = phi::DenseTensor;
-using platform::GetMKLDNNFormat;
 using platform::MKLDNNDeviceContext;
 using platform::MKLDNNGetDataType;
 using platform::to_void_cast;
......
@@ -25,7 +25,6 @@ using dnnl::reorder;
 using dnnl::resampling_forward;
 using dnnl::stream;
 using phi::DataLayout;
-using platform::GetMKLDNNFormat;
 using platform::to_void_cast;

 template <typename T = float>
......
@@ -16,7 +16,6 @@ limitations under the License. */
 namespace {

 using dnnl::memory;
 using paddle::framework::ExecutionContext;
-using paddle::platform::GetMKLDNNFormat;
 using paddle::platform::MatMulV2MKLDNNHandler;
 using paddle::platform::MKLDNNDeviceContext;
 using paddle::platform::MKLDNNFormatForSize;
......
@@ -20,7 +20,6 @@ namespace operators {
 using dnnl::memory;
-using platform::GetMKLDNNFormat;
 using platform::MKLDNNDeviceContext;
 using platform::MKLDNNGetDataType;
 using platform::to_void_cast;
......
@@ -29,7 +29,6 @@ using platform::to_void_cast;
 using Tensor = phi::DenseTensor;
 using dnnl::stream;
 using phi::DataLayout;
-using platform::GetMKLDNNFormat;

 template <typename T>
 class QuantOpKernel : public framework::OpKernel<T> {
......
@@ -30,7 +30,6 @@ enum class ReshapeKernelOpName {
 namespace paddle {
 namespace operators {

-using platform::GetMKLDNNFormat;
 using platform::to_void_cast;

 static std::vector<int> extract_shape(
@@ -83,13 +82,13 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
         onednn_engine);

     auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        x->format(), platform::to_void_cast(x->data<T>()));
+        x->mem_desc(), platform::to_void_cast(x->data<T>()));
     out->Resize(x_dims);  // to match x numel, format is changed later
     // reorder is done into a plain tag to allow usage with blocked formats
     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
         out, getPlainFormatTag(x), ctx.GetPlace());
-    auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
-                                                    reorder_dst_memory_p);
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);

     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
@@ -97,9 +96,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
     astream.wait();

     out->Resize(out_dims);
-    out->set_layout(phi::DataLayout::kMKLDNN);
-    out->set_format(GetMKLDNNFormat(
-        reorder_dst_memory_p->get_desc().reshape(phi::vectorize(out_dims))));
+    out->set_mem_desc(
+        reorder_dst_memory_p->get_desc().reshape(phi::vectorize(out_dims)));
   }

   void InferInOutShape(const framework::ExecutionContext& ctx,
@@ -358,20 +356,18 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
         onednn_engine);

     auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
-        dout->format(), platform::to_void_cast(dout->data<T>()));
+        dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
         dx, this->getPlainFormatTag(dout), ctx.GetPlace());
-    auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
-                                                    reorder_dst_memory_p);
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);

     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();

     dx->Resize(dx_dims);
-    dx->set_layout(phi::DataLayout::kMKLDNN);
-    dx->set_format(GetMKLDNNFormat(
-        reorder_dst_memory_p->get_desc().reshape(phi::vectorize(dx_dims))));
+    dx->set_mem_desc(
+        reorder_dst_memory_p->get_desc().reshape(phi::vectorize(dx_dims)));
   }

   void InferOutputShapeInGrad(const framework::ExecutionContext& ctx,
......
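Both reshape kernels above follow the same pattern: reorder the (possibly blocked) input into a plain layout first, then reshape the destination descriptor to the target dims and store it on the output via set_mem_desc. The AcquireReorder call sites are also swapped to pass the destination memory first, consistent with the (dst, src) order seen in the split-kernel hunk further down. A standalone sketch of the reorder-then-reshape flow, assuming oneDNN 2.x and a CPU engine (dims and names are illustrative, not Paddle's):

#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // Source in a blocked layout; destination in the plain abcd layout.
  dnnl::memory::desc src_md({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nChw8c);
  dnnl::memory::desc dst_md({1, 8, 4, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::abcd);
  dnnl::memory src_mem(src_md, eng), dst_mem(dst_md, eng);

  dnnl::reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
  strm.wait();

  // Descriptor-level reshape is safe on the plain result; this reshaped md is
  // what the kernel would hand to out->set_mem_desc(...).
  auto reshaped_md = dst_mem.get_desc().reshape({1, 8, 16});
  return reshaped_md.dims().size() == 3 ? 0 : 1;
}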
@@ -201,164 +201,6 @@ inline void Reorder(dnnl::memory src,
   astream.wait();
 }

-inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
-  auto ndims = mem_desc.data.ndims;
-  auto strides = mem_desc.data.format_desc.blocking.strides;
-  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
-  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
-  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
-
-  if (ndims == 1) {
-    return dnnl::memory::format_tag::x;
-  } else if (ndims == 2) {
-    if (inner_nblks == 0) {
-      if (strides[0] >= strides[1]) {
-        return dnnl::memory::format_tag::nc;
-      } else {
-        return dnnl::memory::format_tag::cn;
-      }
-    }
-  } else if (ndims == 3) {
-    if (inner_nblks == 0) {
-      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
-        return dnnl::memory::format_tag::ncw;
-      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
-        return dnnl::memory::format_tag::ntc;
-      } else {
-        return dnnl::memory::format_tag::nwc;
-      }
-    }
-  } else if (ndims == 4) {
-    if (inner_nblks == 0) {
-      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-          strides[2] >= strides[3]) {
-        return dnnl::memory::format_tag::abcd;
-      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
-                 strides[1] >= strides[0]) {
-        return dnnl::memory::format_tag::cdba;
-      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
-                 strides[3] >= strides[1]) {
-        return dnnl::memory::format_tag::acdb;
-      } else if (strides[0] >= strides[1] && strides[1] >= strides[3] &&
-                 strides[3] >= strides[2]) {
-        return dnnl::memory::format_tag::abdc;
-      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
-                 strides[1] >= strides[0]) {
-        return dnnl::memory::format_tag::cdba;
-      } else {
-        return dnnl::memory::format_tag::dcab;
-      }
-    } else if (inner_nblks == 1) {
-      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
-        return dnnl::memory::format_tag::nChw16c;
-      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
-        return dnnl::memory::format_tag::nChw8c;
-      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
-        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
-            strides[3] >= strides[1]) {
-          return dnnl::memory::format_tag::Acdb8a;
-        }
-      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
-        return dnnl::memory::format_tag::nChw4c;
-      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
-        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
-            strides[3] >= strides[1]) {
-          return dnnl::memory::format_tag::Acdb16a;
-        }
-      }
-    } else if (inner_nblks == 2) {
-      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
-        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return dnnl::memory::format_tag::OIhw16i16o;
-        }
-      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
-        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return dnnl::memory::format_tag::OIhw8i8o;
-        }
-      }
-    }
-  } else if (ndims == 5) {
-    if (inner_nblks == 0) {
-      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-          strides[2] >= strides[3] && strides[3] >= strides[4]) {
-        return dnnl::memory::format_tag::abcde;
-      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
-                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
-        return dnnl::memory::format_tag::acbde;
-      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
-                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
-        return dnnl::memory::format_tag::acdeb;
-      }
-    } else if (inner_nblks == 1) {
-      if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
-        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return dnnl::memory::format_tag::aBcde4b;
-        }
-      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
-        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
-            strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return dnnl::memory::format_tag::Acdeb8a;
-        }
-        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return dnnl::memory::format_tag::Abcde8a;
-        }
-      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
-        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return dnnl::memory::format_tag::aBcde8b;
-        }
-      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
-        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
-            strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return dnnl::memory::format_tag::Acdeb16a;
-        }
-        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return dnnl::memory::format_tag::Abcde16a;
-        }
-      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
-        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return dnnl::memory::format_tag::aBcde16b;
-        }
-      }
-    }
-  } else if (ndims == 6) {
-    if (inner_nblks == 0) {
-      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
-          strides[2] >= strides[3] && strides[3] >= strides[4] &&
-          strides[4] >= strides[5]) {
-        return dnnl::memory::format_tag::abcdef;
-      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
-                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
-                 strides[4] >= strides[5]) {
-        return dnnl::memory::format_tag::acbdef;
-      }
-    }
-  }
-  // DEBUG CODE - KEEP UNTILL TENSOR.MEMORY_DESC IMPLEMENTED
-  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
-  // std::cout<<"NDIMS: "<<ndims<<std::endl;
-  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
-  // for (int i=0;i<ndims;++i) {
-  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
-  // }
-  // for (int i=0;i<inner_nblks;++i) {
-  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
-  // }
-  // for (int i=0;i<inner_nblks;++i) {
-  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
-  // }
-  return dnnl::memory::format_tag::undef;
-}
-
-inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
-  auto mem_desc = memory.get_desc();
-  return GetMKLDNNFormat(mem_desc);
-}
-
 inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
   switch (tensor_rank) {
     case 1:
......
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+// NOTE:
+// GetMKLDNNFormat function is here temporarily. It is
+// needed because without them forward declaration was causing an error when
+// building with "-DWITH_TESTING=ON". This file will be deleted after completing
+// md-related refactoring
+
+namespace paddle {
+namespace platform {
+
+inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
+  auto ndims = mem_desc.data.ndims;
+  auto strides = mem_desc.data.format_desc.blocking.strides;
+  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
+  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
+  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
+
+  if (ndims == 1) {
+    return dnnl::memory::format_tag::x;
+  } else if (ndims == 2) {
+    if (inner_nblks == 0) {
+      if (strides[0] >= strides[1]) {
+        return dnnl::memory::format_tag::nc;
+      } else {
+        return dnnl::memory::format_tag::cn;
+      }
+    }
+  } else if (ndims == 3) {
+    if (inner_nblks == 0) {
+      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
+        return dnnl::memory::format_tag::ncw;
+      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
+        return dnnl::memory::format_tag::ntc;
+      } else {
+        return dnnl::memory::format_tag::nwc;
+      }
+    }
+  } else if (ndims == 4) {
+    if (inner_nblks == 0) {
+      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+          strides[2] >= strides[3]) {
+        return dnnl::memory::format_tag::abcd;
+      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
+                 strides[1] >= strides[0]) {
+        return dnnl::memory::format_tag::cdba;
+      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
+                 strides[3] >= strides[1]) {
+        return dnnl::memory::format_tag::acdb;
+      } else if (strides[0] >= strides[1] && strides[1] >= strides[3] &&
+                 strides[3] >= strides[2]) {
+        return dnnl::memory::format_tag::abdc;
+      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
+                 strides[1] >= strides[0]) {
+        return dnnl::memory::format_tag::cdba;
+      } else {
+        return dnnl::memory::format_tag::dcab;
+      }
+    } else if (inner_nblks == 1) {
+      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
+        return dnnl::memory::format_tag::nChw16c;
+      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
+        return dnnl::memory::format_tag::nChw8c;
+      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
+        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
+            strides[3] >= strides[1]) {
+          return dnnl::memory::format_tag::Acdb8a;
+        }
+      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
+        return dnnl::memory::format_tag::nChw4c;
+      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
+        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
+            strides[3] >= strides[1]) {
+          return dnnl::memory::format_tag::Acdb16a;
+        }
+      }
+    } else if (inner_nblks == 2) {
+      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
+        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
+          return dnnl::memory::format_tag::OIhw16i16o;
+        }
+      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
+        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
+          return dnnl::memory::format_tag::OIhw8i8o;
+        }
+      }
+    }
+  } else if (ndims == 5) {
+    if (inner_nblks == 0) {
+      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+          strides[2] >= strides[3] && strides[3] >= strides[4]) {
+        return dnnl::memory::format_tag::abcde;
+      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
+                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
+        return dnnl::memory::format_tag::acbde;
+      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
+                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
+        return dnnl::memory::format_tag::acdeb;
+      }
+    } else if (inner_nblks == 1) {
+      if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
+        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
+            strides[3] >= strides[4] && strides[4] >= strides[1]) {
+          return dnnl::memory::format_tag::Acdeb8a;
+        }
+        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+            strides[2] >= strides[3] && strides[3] >= strides[4]) {
+          return dnnl::memory::format_tag::Abcde8a;
+        }
+      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
+        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+            strides[2] >= strides[3] && strides[3] >= strides[4]) {
+          return dnnl::memory::format_tag::aBcde8b;
+        }
+      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
+        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
+            strides[3] >= strides[4] && strides[4] >= strides[1]) {
+          return dnnl::memory::format_tag::Acdeb16a;
+        }
+        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+            strides[2] >= strides[3] && strides[3] >= strides[4]) {
+          return dnnl::memory::format_tag::Abcde16a;
+        }
+      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
+        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+            strides[2] >= strides[3] && strides[3] >= strides[4]) {
+          return dnnl::memory::format_tag::aBcde16b;
+        }
+      }
+    }
+  } else if (ndims == 6) {
+    if (inner_nblks == 0) {
+      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+          strides[2] >= strides[3] && strides[3] >= strides[4] &&
+          strides[4] >= strides[5]) {
+        return dnnl::memory::format_tag::abcdef;
+      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
+                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
+                 strides[4] >= strides[5]) {
+        return dnnl::memory::format_tag::acbdef;
+      }
+    }
+  }
+  // DEBUG CODE - KEEP UNTILL TENSOR.MEMORY_DESC IMPLEMENTED
+  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
+  // std::cout<<"NDIMS: "<<ndims<<std::endl;
+  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
+  // for (int i=0;i<ndims;++i) {
+  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
+  // }
+  // for (int i=0;i<inner_nblks;++i) {
+  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
+  // }
+  // for (int i=0;i<inner_nblks;++i) {
+  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
+  // }
+  return dnnl::memory::format_tag::undef;
+}
+
+}  // namespace platform
+}  // namespace paddle
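A brief usage sketch for the relocated helper above, assuming it is reachable as paddle::platform::GetMKLDNNFormat and oneDNN 2.x headers are available (dims invented):

// Plain nchw: strictly decreasing strides, no inner blocks, so the
// ndims == 4 branch returns abcd (nchw and abcd name the same layout).
dnnl::memory::desc plain_md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
auto plain_tag = paddle::platform::GetMKLDNNFormat(plain_md);

// Blocked nChw16c: inner_nblks == 1, inner_blks[0] == 16, inner_idxs[0] == 1,
// so the blocked ndims == 4 branch returns nChw16c.
dnnl::memory::desc blocked_md({2, 32, 4, 5}, dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::nChw16c);
auto blocked_tag = paddle::platform::GetMKLDNNFormat(blocked_md);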
@@ -860,6 +860,16 @@ class ReorderOneDNNHandler {
     }
   }

+  std::shared_ptr<dnnl::memory> AcquireDstMemory(
+      DenseTensor* output,
+      const std::vector<int64_t>& dims,
+      const std::vector<int64_t>& strides,
+      Place place) {
+    auto dst_md = dnnl::memory::desc(dims, dtype_dst_, strides);
+    auto dst_data = output->mutable_data(place, ptype_dst_, dst_md.get_size());
+    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
+  }
+
   std::shared_ptr<dnnl::memory> AcquireDstMemory(
       DenseTensor* output,
       const std::vector<int64_t>& dims,
......
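The new overload above builds the destination descriptor from explicit strides instead of a format tag, so callers can describe a non-contiguous slice of a larger buffer. A minimal sketch of the underlying descriptor construction (standalone oneDNN 2.x; the values are invented):

#include <cstdint>
#include <dnnl.hpp>
#include <vector>

int main() {
  std::vector<int64_t> dims = {2, 2, 4};
  std::vector<int64_t> strides = {8, 4, 1};  // row-major strides for {2, 2, 4}
  dnnl::memory::desc dst_md(dims, dnnl::memory::data_type::f32, strides);
  // dst_md.get_size() is the byte size the handler allocates for the output.
  return dst_md.get_size() == 2 * 2 * 4 * sizeof(float) ? 0 : 1;
}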
@@ -58,7 +58,6 @@ DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) {
   inplace_version_counter_ = other.inplace_version_counter_;

 #ifdef PADDLE_WITH_MKLDNN
-  format_ = other.format_;
   mem_desc_ = other.mem_desc_;
 #endif
 }
@@ -70,7 +69,6 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
       std::move(CopyStorageProperties(other.storage_properties_));
   inplace_version_counter_ = other.inplace_version_counter_;
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = other.format_;
   mem_desc_ = other.mem_desc_;
 #endif
   return *this;
@@ -82,7 +80,6 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
   storage_properties_ = std::move(other.storage_properties_);
   std::swap(inplace_version_counter_, other.inplace_version_counter_);
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = other.format_;
   mem_desc_ = other.mem_desc_;
 #endif
   return *this;
......
@@ -274,16 +274,6 @@ In the final state, we should come up with a MKLDNN_Tensor and move the
 following codes there.
 */
 #ifdef PADDLE_WITH_MKLDNN
-  /**
-   * @brief the detail format of memory block which have layout as kMKLDNN
-   *
-   * @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
-   *       nChw16c, etc. For a MKLDNN memory block, layout will be set as
-   *       DataLayout::kMKLDNN meanwhile detail memory format will be kept in
-   *       this field.
-   */
-  dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef;
-
   /// \brief memory descriptor of tensor which have layout set as kMKLDNN
   dnnl::memory::desc mem_desc_;
 #endif
......
@@ -123,12 +123,6 @@ inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
     meta_.layout = DataLayout::kMKLDNN;
   }

-  dnnl::memory::format_tag format() const;
-
-  inline void set_format(const dnnl::memory::format_tag format) {
-    format_ = format;
-  }
-
 #endif

 /* ------------------------------ */
......
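With format_ gone, a single dnnl::memory::desc now carries dims, data type, and physical layout, and set_mem_desc also switches the layout to kMKLDNN (see the retained lines above). A hypothetical round-trip on a phi::DenseTensor t, for illustration only:

dnnl::memory::desc md({8, 16}, dnnl::memory::data_type::f32,
                      dnnl::memory::format_tag::nc);
t.set_mem_desc(md);           // also sets meta_.layout = DataLayout::kMKLDNN
auto same_md = t.mem_desc();  // returns mem_desc_ directly (see the next hunk)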
@@ -19,10 +19,6 @@ limitations under the License. */
 #include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/dense_tensor.h"

-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_utils.h"
-#endif
-
 namespace phi {

 /* --------------------------- */
 /* From phi::DenseTensor    */
@@ -348,16 +344,7 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
 }

 #ifdef PADDLE_WITH_MKLDNN
-dnnl::memory::desc DenseTensor::mem_desc() const {
-  return mem_desc_ ? mem_desc_
-                   : dnnl::memory::desc(phi::vectorize(meta_.dims),
-                                        phi::TransToOneDNNDataType(meta_.dtype),
-                                        format_);
-}
-
-dnnl::memory::format_tag DenseTensor::format() const {
-  return mem_desc_ ? paddle::platform::GetMKLDNNFormat(mem_desc_) : format_;
-}
+dnnl::memory::desc DenseTensor::mem_desc() const { return mem_desc_; }
 #endif

 // NOTE: For historical reasons, this interface has a special behavior,
@@ -373,7 +360,6 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
   storage_properties_ =
       std::move(CopyStorageProperties(src.storage_properties_));
 #ifdef PADDLE_WITH_MKLDNN
-  format_ = src.format_;
   mem_desc_ = src.mem_desc_;
 #endif
   return *this;
......
@@ -19,6 +19,23 @@
 namespace phi {

+const std::vector<int64_t> get_slice_strides(
+    const std::vector<int64_t>& out_vec_dims,
+    const dnnl::memory::desc& full_md,
+    int axis) {
+  auto strides = full_md.data.format_desc.blocking.strides;
+  auto ndims = full_md.data.ndims;
+  auto full_dims = full_md.data.dims;
+
+  auto splitted_stride = strides[axis];
+  std::vector<int64_t> slice_strides(ndims, splitted_stride);
+  for (int16_t i = 0; i < ndims; ++i) {
+    slice_strides[i] = strides[i] > splitted_stride
+                           ? (strides[i] / full_dims[axis]) * out_vec_dims[axis]
+                           : strides[i];
+  }
+  return slice_strides;
+}
+
 template <typename T, typename Context>
 void SplitKernel(const Context& dev_ctx,
                  const DenseTensor& x,
@@ -49,7 +66,10 @@ void SplitKernel(const Context& dev_ctx,
         out_vec_dims, offset, reorder_src_memory_p);

     auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
-        out[i], out_vec_dims, x.format(), dev_ctx.GetPlace());
+        out[i],
+        out_vec_dims,
+        get_slice_strides(out_vec_dims, x.mem_desc(), axis),
+        dev_ctx.GetPlace());

     auto reorder_p =
         reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);