Unverified commit 19746835, authored by jakpiase, committed by GitHub

OneDNN md-in-tensor refactoring: Added support for md in transpose (#46620)

* added transpose

* CI fix

* fix for transpose

* fix after review
Parent a579e523
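The gist of the change: instead of the hand-rolled TransposeMKLDNNHandler, the kernel now expresses transpose as a single oneDNN reorder whose destination memory descriptor is built with "faked" transposed strides, and the output tensor then stores a permuted descriptor (md-in-tensor). The following minimal, self-contained C++ sketch mirrors the FakeTranposeStrides and TransposeToPermuteAxis helpers added in this diff; the shape {2, 3, 4, 5} and axis {0, 2, 3, 1} are picked purely for illustration.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors FakeTranposeStrides from this diff: strides of a buffer whose data
// is physically laid out as if the dims had been permuted by transpose_axis.
std::vector<int64_t> FakeTranposeStrides(const std::vector<int64_t>& dims,
                                         const std::vector<int>& transpose_axis) {
  std::vector<int64_t> strides(dims.size());
  int64_t total_stride = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[transpose_axis[i]] = total_stride;
    total_stride *= dims[transpose_axis[i]];
  }
  return strides;
}

// Mirrors TransposeToPermuteAxis: oneDNN's permute_axes expects the inverse
// permutation of the Paddle "axis" attribute.
std::vector<int> TransposeToPermuteAxis(const std::vector<int>& transpose_axis) {
  std::vector<int> permute_axis(transpose_axis.size());
  for (size_t i = 0; i < transpose_axis.size(); ++i) {
    permute_axis[transpose_axis[i]] = static_cast<int>(i);
  }
  return permute_axis;
}

int main() {
  // Hypothetical example: NCHW input of shape {2, 3, 4, 5}, axis = {0, 2, 3, 1}.
  std::vector<int64_t> dims = {2, 3, 4, 5};
  std::vector<int> axis = {0, 2, 3, 1};

  for (int64_t s : FakeTranposeStrides(dims, axis)) std::cout << s << ' ';
  std::cout << '\n';  // prints "60 1 15 3": strides of the transposed buffer
  for (int p : TransposeToPermuteAxis(axis)) std::cout << p << ' ';
  std::cout << '\n';  // prints "0 3 1 2": inverse permutation for permute_axes
  return 0;
}
```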
@@ -24,70 +24,6 @@ namespace operators {
using Tensor = phi::DenseTensor;
using framework::DataLayout;
template <typename T>
class TransposeMKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
dnnl::engine engine)
: dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}
auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor* output,
platform::Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
std::shared_ptr<dnnl::reorder> AcquireTranspose(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}
protected:
dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
dnnl::memory::desc mem_d(
nchw_tz, platform::MKLDNNGetDataType<T>(), strides);
return mem_d;
}
private:
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
dnnl::engine engine_;
};
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
@@ -98,37 +34,84 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
int ndims = axis.size();
auto* input = ctx.Input<phi::DenseTensor>("X");
auto* output = ctx.Output<phi::DenseTensor>("Out");
const T* input_data = input->data<T>();
const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
int ndims = transpose_axis.size();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (ndims == 1) {
framework::TensorCopy(*input, input->place(), output);
output->set_format(input->format());
framework::TensorCopy(*x, x->place(), out);
out->set_mem_desc(x->mem_desc());
return;
}
auto nchw_tz = phi::vectorize<int64_t>(input->dims());
auto x_vec_dims = phi::vectorize(x->dims());
TransposeMKLDNNHandler<T> handler(nchw_tz, axis, mkldnn_engine);
framework::proto::VarType::Type x_paddle_type =
framework::TransToProtoVarType(x->dtype());
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims, x_paddle_type, x_type, dnnl_engine);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
transpose_p->execute(
astream, *transpose_src_memory_p, *transpose_dst_memory_p);
auto dst_md =
dnnl::memory::desc(x_vec_dims,
x->mem_desc().data_type(),
platform::GetPlainMKLDNNFormat(x_vec_dims.size()));
// A trick is used here: dst_md gets "faked" transposed strides, so that it
// can later be "untransposed" with permute_axes, leaving the output data
// described by a plain memory descriptor.
auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
dst_md =
dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides);
auto dst_data =
out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size());
auto reorder_dst_memory_p =
std::make_shared<dnnl::memory>(dst_md, dnnl_engine, dst_data);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
output->set_layout(DataLayout::kNCHW);
output->set_format(MKLDNNMemoryFormat::undef);
out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
TransposeToPermuteAxis(transpose_axis)));
}
private:
// This is needed because oneDNN's permute_axes interprets the axis order
// differently than PaddlePaddle's transpose attribute does.
std::vector<int> TransposeToPermuteAxis(
const std::vector<int>& transpose_axis) const {
std::vector<int> permute_axis(transpose_axis.size());
for (size_t i = 0; i < transpose_axis.size(); ++i) {
permute_axis[transpose_axis[i]] = i;
}
return permute_axis;
}
std::vector<int64_t> FakeTranposeStrides(
const dnnl::memory::desc& dst_md,
const std::vector<int>& transpose_axis) const {
std::vector<int64_t> fake_strides(transpose_axis.size());
auto dims = dst_md.dims();
int total_stride = 1;
int ndims = static_cast<int>(dims.size());
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= dims[transpose_axis[i]];
}
return fake_strides;
}
};
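The "faking" trick above may be easier to follow as a concrete call sequence. The sketch below assumes the oneDNN 2.x C++ API (dnnl.hpp) and reuses the hypothetical shape {2, 3, 4, 5} and axis {0, 2, 3, 1} from the earlier example: the reorder physically rearranges the data into the strided destination, and permute_axes then yields the plain-layout descriptor that the kernel stores via set_mem_desc.

```cpp
#include <dnnl.hpp>
#include <vector>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // Hypothetical input: NCHW tensor of shape {2, 3, 4, 5}, axis = {0, 2, 3, 1}.
  dnnl::memory::dims dims = {2, 3, 4, 5};

  // Source descriptor: plain NCHW layout.
  auto src_md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::nchw);
  // Destination descriptor: same dims, strides faked as if already transposed
  // (the result of FakeTranposeStrides(dims, axis) for these values).
  dnnl::memory::dims fake_strides = {60, 1, 15, 3};
  auto dst_md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, fake_strides);

  std::vector<float> src_buf(120, 0.f), dst_buf(120, 0.f);
  dnnl::memory src_mem(src_md, eng, src_buf.data());
  dnnl::memory dst_mem(dst_md, eng, dst_buf.data());

  // The reorder performs the actual data movement.
  dnnl::reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
  strm.wait();

  // Permuting with the inverse permutation {0, 3, 1, 2} turns dst_md into a
  // plain descriptor of shape {2, 4, 5, 3} with strides {60, 15, 3, 1}; this
  // is what the kernel stores in the output tensor via set_mem_desc().
  auto out_md = dst_mem.get_desc().permute_axes({0, 3, 1, 2});
  (void)out_md;
  return 0;
}
```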
@@ -140,43 +123,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace"));
auto* out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
if (!x_grad) return;
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
if (!dx) return;
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
int ndims = axis.size();
const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
int ndims = transpose_axis.size();
if (ndims == 1) {
framework::TensorCopy(*out_grad, out_grad->place(), x_grad);
x_grad->set_format(out_grad->format());
framework::TensorCopy(*dout, dout->place(), dx);
dx->set_mem_desc(dout->mem_desc());
return;
}
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
auto dout_vec_dims = phi::vectorize(dout->dims());
const T* out_grad_data = out_grad->data<T>();
x_grad->mutable_data<T>(ctx.GetPlace());
framework::proto::VarType::Type dout_paddle_type =
framework::TransToProtoVarType(dout->dtype());
dnnl::memory::data_type dout_type =
framework::ToMKLDNNDataType(dout_paddle_type);
auto nchw_tz = phi::vectorize<int64_t>(out_grad->dims());
platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine);
TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, mkldnn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
auto transpose_src_memory_p = handler.AcquireSrcMemory(
out_grad->format(), platform::to_void_cast<T>(out_grad_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
transpose_p->execute(
astream, *transpose_src_memory_p, *transpose_dst_memory_p);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_mem_desc(
reorder_dst_memory_p->get_desc().permute_axes(transpose_axis));
}
};
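The backward kernel works the other way around: the reorder is effectively a copy (dx is allocated with dout's memory descriptor on both sides), and the transposition happens only in metadata. Since the gradient of transpose(axis) is transpose(inverse(axis)), the permutation handed to permute_axes is the inverse of that, i.e. transpose_axis itself, which matches the set_mem_desc call above. A small sketch under the same oneDNN 2.x API assumptions and the same hypothetical shapes:

```cpp
#include <dnnl.hpp>
#include <cassert>

int main() {
  // Gradient of the forward example: dout has the already-transposed shape
  // {2, 4, 5, 3}; the Paddle "axis" attribute is still {0, 2, 3, 1}.
  auto dout_md = dnnl::memory::desc({2, 4, 5, 3}, dnnl::memory::data_type::f32,
                                    dnnl::memory::format_tag::abcd);

  // dx is filled by a plain copy of dout; only the descriptor is then
  // permuted with transpose_axis, recovering the original input shape.
  auto dx_md = dout_md.permute_axes({0, 2, 3, 1});
  assert((dx_md.dims() == dnnl::memory::dims{2, 3, 4, 5}));
  return 0;
}
```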