未验证 提交 493fbfd7 编写于 作者: J Jacek Czaja 提交者: GitHub

Update of PHI transpose_grad (#47311)

* - halfway transforming transpose grad

- Fixes

- buildable

* - lint

* rerunning the process
上级 77dbb318
...@@ -1085,69 +1085,5 @@ class ClipOneDNNHandler ...@@ -1085,69 +1085,5 @@ class ClipOneDNNHandler
to_void_cast<T>(input_data)); to_void_cast<T>(input_data));
} }
}; };
template <typename T>
class TransposeOneDNNHandler {
public:
TransposeOneDNNHandler(const OneDNNContext& dev_ctx,
std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
dnnl::engine engine)
: dev_ctx_(dev_ctx),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const OneDNNMemoryFormat& fmt,
void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}
auto src_md = fmt != OneDNNMemoryFormat::nchw
? OneDNNMemDesc(dims_, OneDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output,
Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = dev_ctx_.Alloc<T>(output);
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
std::shared_ptr<dnnl::reorder> AcquireTranspose(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}
protected:
dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
dnnl::memory::desc mem_d(nchw_tz, OneDNNGetDataType<T>(), strides);
return mem_d;
}
private:
const OneDNNContext& dev_ctx_;
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
dnnl::engine engine_;
};
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -31,35 +31,33 @@ void TransposeGradKernel(const Context& dev_ctx, ...@@ -31,35 +31,33 @@ void TransposeGradKernel(const Context& dev_ctx,
if (!x_grad) return; if (!x_grad) return;
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
std::vector<int> reversed_axis(axis);
if (axis.size() == 1) { if (axis.size() == 1) {
paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad); paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad);
x_grad->set_format(out_grad.format()); x_grad->set_mem_desc(out_grad.mem_desc());
return; return;
} }
for (size_t i = 0; i < axis.size(); i++) { std::vector<int64_t> out_grad_tz = vectorize(out_grad.dims());
reversed_axis[axis[i]] = i; funcs::ReorderOneDNNHandler reorder_handler(
} out_grad_tz,
out_grad.dtype(),
funcs::ToOneDNNDataType(out_grad.dtype()),
onednn_engine);
const T* out_grad_data = out_grad.data<T>(); auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dev_ctx.template Alloc<T>(x_grad); out_grad.mem_desc(), funcs::to_void_cast(out_grad.data<T>()));
auto nchw_tz = vectorize<int64_t>(out_grad.dims());
funcs::TransposeOneDNNHandler<T> handler( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dev_ctx, nchw_tz, reversed_axis, onednn_engine); x_grad, out_grad.mem_desc(), dev_ctx.GetPlace());
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
out_grad.format(), funcs::to_void_cast<T>(out_grad_data)); reorder_src_memory_p);
auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, dev_ctx.GetPlace());
auto transpose_p =
handler.AcquireTranspose(transpose_dst_memory_p, transpose_src_memory_p);
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
transpose_p->execute( reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream, *transpose_src_memory_p, *transpose_dst_memory_p);
astream.wait(); astream.wait();
x_grad->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(axis));
} }
} // namespace phi } // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册