未验证 提交 19746835 编写于 作者: J jakpiase 提交者: GitHub

OneDNN md-in-tensor refactoring: Added support for md in transpose (#46620)

* added transpose

* CI fix

* fix for transpose

* fix after review
上级 a579e523
...@@ -24,70 +24,6 @@ namespace operators { ...@@ -24,70 +24,6 @@ namespace operators {
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
using framework::DataLayout; using framework::DataLayout;
template <typename T>
class TransposeMKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
dnnl::engine engine)
: dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}
auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor* output,
platform::Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
std::shared_ptr<dnnl::reorder> AcquireTranspose(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}
protected:
dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
dnnl::memory::desc mem_d(
nchw_tz, platform::MKLDNNGetDataType<T>(), strides);
return mem_d;
}
private:
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
dnnl::engine engine_;
};
template <typename T> template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -98,37 +34,84 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -98,37 +34,84 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL Transpose must use CPUPlace")); "Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis"); std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
int ndims = axis.size(); int ndims = transpose_axis.size();
auto* input = ctx.Input<phi::DenseTensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* output = ctx.Output<phi::DenseTensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
const T* input_data = input->data<T>();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (ndims == 1) { if (ndims == 1) {
framework::TensorCopy(*input, input->place(), output); framework::TensorCopy(*x, x->place(), out);
output->set_format(input->format()); out->set_mem_desc(x->mem_desc());
return; return;
} }
auto nchw_tz = phi::vectorize<int64_t>(input->dims()); auto x_vec_dims = phi::vectorize(x->dims());
TransposeMKLDNNHandler<T> handler(nchw_tz, axis, mkldnn_engine); framework::proto::VarType::Type x_paddle_type =
framework::TransToProtoVarType(x->dtype());
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims, x_paddle_type, x_type, dnnl_engine);
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data)); x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto dst_md =
transpose_p->execute( dnnl::memory::desc(x_vec_dims,
astream, *transpose_src_memory_p, *transpose_dst_memory_p); x->mem_desc().data_type(),
platform::GetPlainMKLDNNFormat(x_vec_dims.size()));
// A trick is used here: dst_md gets strides as if it were already
// transposed, so the final permute_axes "untransposes" it, leaving the
// output data in a plain format tag
auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
dst_md =
dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides);
auto dst_data =
out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size());
auto reorder_dst_memory_p =
std::make_shared<dnnl::memory>(dst_md, dnnl_engine, dst_data);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
output->set_layout(DataLayout::kNCHW); out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
output->set_format(MKLDNNMemoryFormat::undef); TransposeToPermuteAxis(transpose_axis)));
}
private:
// This is needed because oneDNN's permute_axes understands the axis
// order differently than PaddlePaddle's transpose attribute does
std::vector<int> TransposeToPermuteAxis(
const std::vector<int>& transpose_axis) const {
std::vector<int> permute_axis(transpose_axis.size());
for (size_t i = 0; i < transpose_axis.size(); ++i) {
permute_axis[transpose_axis[i]] = i;
}
return permute_axis;
}
std::vector<int64_t> FakeTranposeStrides(
const dnnl::memory::desc& dst_md,
const std::vector<int>& transpose_axis) const {
std::vector<int64_t> fake_strides(transpose_axis.size());
auto dims = dst_md.dims();
int total_stride = 1;
int ndims = static_cast<int>(dims.size());
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= dims[transpose_axis[i]];
}
return fake_strides;
} }
}; };
...@@ -140,43 +123,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -140,43 +123,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
true, true,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace")); "Operator DNNL TransposeGrad must use CPUPlace"));
auto* out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<phi::DenseTensor>(framework::GradVarName("X")); const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (!x_grad) return; auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
if (!dx) return;
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis"); std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
int ndims = axis.size(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
int ndims = transpose_axis.size();
if (ndims == 1) { if (ndims == 1) {
framework::TensorCopy(*out_grad, out_grad->place(), x_grad); framework::TensorCopy(*dout, dout->place(), dx);
x_grad->set_format(out_grad->format()); dx->set_mem_desc(dout->mem_desc());
return; return;
} }
for (size_t i = 0; i < axis.size(); i++) { auto dout_vec_dims = phi::vectorize(dout->dims());
reversed_axis[axis[i]] = i;
}
const T* out_grad_data = out_grad->data<T>(); framework::proto::VarType::Type dout_paddle_type =
x_grad->mutable_data<T>(ctx.GetPlace()); framework::TransToProtoVarType(dout->dtype());
dnnl::memory::data_type dout_type =
framework::ToMKLDNNDataType(dout_paddle_type);
auto nchw_tz = phi::vectorize<int64_t>(out_grad->dims()); platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine);
TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, mkldnn_engine); auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto reorder_dst_memory_p =
out_grad->format(), platform::to_void_cast<T>(out_grad_data)); reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());
auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
transpose_p->execute( reorder_src_memory_p);
astream, *transpose_src_memory_p, *transpose_dst_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
dx->set_mem_desc(
reorder_dst_memory_p->get_desc().permute_axes(transpose_axis));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册