Unverified commit 416e0de7, authored by Jacek Czaja, committed by GitHub

updating mul and matmul with set_mem_desc (#45624)

* - mul & matmul changes

- fix

- bs16 correction of strides

* - cosmetic fixes

* - lint

* - fix

* - fix

* - format -> mem_desc

* - fix

* - fix

* - fix

* - fix

* - fix
Parent: 5022dd9b
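The recurring change in the hunks below is to stop tagging output Tensors with set_layout()/set_format() and instead attach the oneDNN destination memory descriptor, reshaped to the tensor's logical dims, via set_mem_desc(). The following standalone sketch is illustrative only (it is not taken from Paddle and assumes oneDNN is available as oneapi/dnnl/dnnl.hpp; the names flat_md, dst, and out_md are made up); it shows the get_desc()/reshape() pair the new code paths rely on:

    #include <iostream>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      using namespace dnnl;
      engine eng(engine::kind::cpu, 0);

      // Flattened 2D result as a matmul-style primitive produces it:
      // [6 x 20], plain row-major layout.
      memory::desc flat_md({6, 20}, memory::data_type::f32,
                           memory::format_tag::ab);
      memory dst(flat_md, eng);

      // Re-express the same dense buffer with the tensor's logical 4D dims,
      // mirroring dst_memory_p->get_desc().reshape(...) in the diff.
      memory::desc out_md = dst.get_desc().reshape({2, 3, 4, 5});

      // The buffer size is unchanged; only the logical view differs.
      std::cout << "bytes: " << out_md.get_size() << "\n";
      return 0;
    }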
@@ -214,10 +214,7 @@ class MatMulMKLDNNHandler
     }
     astream.wait();
-    auto format =
-        MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
-    out->set_format(format);
-    out->set_layout(DataLayout::kMKLDNN);
+    out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
   }

   std::shared_ptr<dnnl::memory> AcquireDstMemory(
@@ -651,10 +648,18 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
   auto &astream = MKLDNNDeviceContext::tls().get_stream();
   matmul_p->execute(astream, matmul_args);
   astream.wait();
-  auto format =
-      MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
-  out->set_format(format);
-  out->set_layout(DataLayout::kMKLDNN);
+  // TODO(jczaja): Explain why the int8 format of dst is ABCD and does not
+  // need a permute
+  if (IsOutputFused(ctx) && !IsInt8<T_out>()) {
+    auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
+    auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
+    out->set_mem_desc(
+        permuted_md.reshape(phi::vectorize<int64_t>(out->dims())));
+  } else {
+    out->set_mem_desc(
+        dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
+  }
 }

 template <typename T>
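For the fused_transpose_Out branch above, the output descriptor is first permuted and then reshaped. Below is a standalone sketch of dnnl::memory::desc::permute_axes() (illustrative only, not from the commit; the dims and the axis vector are made up):

    #include <iostream>
    #include <vector>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      using namespace dnnl;

      // Plain NCHW-like descriptor for a [2, 3, 4, 5] tensor.
      memory::desc md({2, 3, 4, 5}, memory::data_type::f32,
                      memory::format_tag::abcd);

      // permute_axes moves logical axis i of `md` to position axis[i] of the
      // result and adjusts strides, so the result describes the same physical
      // buffer viewed as a transposed tensor.
      std::vector<int> axis = {0, 2, 1, 3};
      memory::desc permuted = md.permute_axes(axis);

      // No data moves; the descriptor still covers the same number of bytes.
      std::cout << "bytes: " << permuted.get_size() << "\n";
      return 0;
    }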
@@ -836,8 +841,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
     reduction_p->execute(astream, reduction_args);
     astream.wait();
-    dx->set_format(paddle::platform::GetMKLDNNFormat(
-        dst_memory_p->get_desc().reshape(squeezed_dims)));
+    dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
   }

   std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
@@ -1119,9 +1123,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
   matmul_p->execute(astream, matmul_args);
   astream.wait();
-  out->set_layout(framework::DataLayout::kMKLDNN);
-  out->set_format(platform::GetMKLDNNFormat(
-      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
+  out->set_mem_desc(
+      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
 }

 template <typename T>
@@ -1184,13 +1187,13 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext &ctx) const {
   if (dx) {
     if (dx_dims != x.dims()) {
       dx->Resize(dx_dims);
-      dx->set_format(x.format());
+      dx->set_mem_desc(x.mem_desc());
     }
   }
   if (dy) {
     if (dy_dims != y.dims()) {
       dy->Resize(dy_dims);
-      dy->set_format(y.format());
+      dy->set_mem_desc(y.mem_desc());
     }
   }
 }
......
@@ -221,7 +221,7 @@ class MulPrimitiveFactory {
                      to_void_cast<T>(x_tmp.data<T>()));
       x_tmp.Resize(data->dims());
-      x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
+      x_tmp.set_mem_desc(dst_mdesc);
       data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
     } else {
       data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
@@ -235,11 +235,7 @@ class MulPrimitiveFactory {
                           const Tensor *in) {
     x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
     output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
-    if (out->format() == MKLDNNMemoryFormat::undef) {
-      auto output_format = platform::GetMKLDNNFormat(*output_);
-      out->set_format((MKLDNNMemoryFormat)output_format);
-    }
+    out->set_mem_desc(output_->get_desc());
   }

   template <typename T>
@@ -272,7 +268,7 @@ class MulPrimitiveFactory {
     auto buffer_size = dst_desc.get_size();
     OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
-    output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
+    output->set_mem_desc(dst_desc);
     return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
   }
@@ -392,9 +388,10 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
     if (out_dims.size() != 2) {
       out->Resize(out_dims);
     }
-    out->set_layout(DataLayout::kMKLDNN);
-    out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
-                                                  MKLDNNMemoryFormat::nchw));
+    auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md(
+        mul.get_primitive_desc(), dnnl_query_dst_md, 0));
+    out->set_mem_desc(in_md.reshape(phi::vectorize<int64_t>(out->dims())));
   }
 };
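The INT8 mul kernel above now queries the destination md straight from the primitive descriptor (dnnl_query_dst_md) and reshapes it to the tensor's logical dims. Below is a rough C++-API equivalent, assuming the oneDNN 2.x API in use by Paddle at the time of this commit; all shapes and names are made up for illustration:

    #include <iostream>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      using namespace dnnl;
      engine eng(engine::kind::cpu, 0);

      // [6 x 4] * [4 x 5] -> [6 x 5], all plain row-major f32.
      memory::desc src_md({6, 4}, memory::data_type::f32, memory::format_tag::ab);
      memory::desc wei_md({4, 5}, memory::data_type::f32, memory::format_tag::ab);
      memory::desc dst_md({6, 5}, memory::data_type::f32, memory::format_tag::ab);

      matmul::desc op_desc(src_md, wei_md, dst_md);
      matmul::primitive_desc pd(op_desc, eng);

      // C++ counterpart of dnnl_primitive_desc_query_md(pd, dnnl_query_dst_md, 0).
      memory::desc queried = pd.dst_desc();

      // Unflatten the [6, 5] result back to the framework tensor's dims.
      memory::desc out_md = queried.reshape({2, 3, 5});
      std::cout << "bytes: " << out_md.get_size() << "\n";
      return 0;
    }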
@@ -442,10 +439,11 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
     matmul_p->execute(astream, matmul_args);
     astream.wait();
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    // plain output formats are enforced inside handler
-    out->set_format(platform::MKLDNNFormatForSize(
-        out->dims().size(), dnnl::memory::format_tag::nchw));
+    // This kernel flattens dims, so the unflattened version has to be set on
+    // out. The reshape requires a plain layout, but MatMulV2MKLDNNHandler
+    // enforces one, so it should work.
+    out->set_mem_desc(
+        dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
   }

  private:
......
@@ -196,7 +196,8 @@ class MatMulV2MKLDNNHandler
       out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
     }

-    if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) {
+    // TODO(jczaja): Why not for int8??
+    if (!IsInt8<OT>() && is_output_fused) {
       out_strides = FakeTransposeStrides(out_ddims);
     }
......
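For reference, the out_strides loop visible in the MatMulV2MKLDNNHandler hunk above computes plain row-major strides; FakeTransposeStrides replaces them only when the output transpose is fused. A minimal sketch of the row-major case (the helper name RowMajorStrides is hypothetical, not a Paddle function):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Innermost dimension is contiguous; each outer stride is the product of
    // the dims to its right, exactly as the handler's loop fills out_strides.
    std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& dims) {
      std::vector<int64_t> strides(dims.size(), 1);
      for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
        strides[i] = dims[i + 1] * strides[i + 1];
      }
      return strides;
    }

    int main() {
      for (int64_t s : RowMajorStrides({2, 3, 4, 5})) std::cout << s << " ";
      std::cout << "\n";  // prints: 60 20 5 1
      return 0;
    }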
@@ -121,8 +121,10 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
       OneDNNContext::tls().set_cur_paddle_data_layout(src_layout);
     }

-    out->set_layout(DataLayout::ONEDNN);
-    out->set_format(out_format);
+    dnnl::memory::desc out_mem_desc(vectorize<int64_t>(out->dims()),
+                                    funcs::ToOneDNNDataType(x.dtype()),
+                                    out_format);
+    out->set_mem_desc(out_mem_desc);
   } else if (src_layout == DataLayout::ONEDNN &&
              dst_layout != DataLayout::ONEDNN) {
     // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
......
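The transfer_layout change above builds the descriptor explicitly from dims, a data type, and a format tag instead of tagging the tensor with a format enum. Below is a standalone sketch of that constructor (illustrative values only, not from the commit):

    #include <iostream>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      using namespace dnnl;

      memory::dims out_dims = {8, 3, 224, 224};  // logical N, C, H, W

      // Same logical dims, two different physical layouts.
      memory::desc nchw_md(out_dims, memory::data_type::f32,
                           memory::format_tag::nchw);
      memory::desc nhwc_md(out_dims, memory::data_type::f32,
                           memory::format_tag::nhwc);

      std::cout << "both cover " << nchw_md.get_size() << " bytes, "
                << (nchw_md == nhwc_md ? "equal" : "different layouts") << "\n";
      return 0;
    }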