未验证 提交 416e0de7 编写于 作者: J Jacek Czaja 提交者: GitHub

updating mul and matmul with set_mem_desc (#45624)

* - mul & matmul changes

- fix

- bs16 correction of strides

* - cosmetic fixes

* - lint

* - fix

* - fix

* - format -> mem_desc

* - fix

* - fix

* - fix

* - fix

* - fix
上级 5022dd9b
......@@ -214,10 +214,7 @@ class MatMulMKLDNNHandler
}
astream.wait();
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(
......@@ -651,10 +648,18 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
// TODO(jczaja): Explain why int8 format of dst is ABCD and do not need
// permute
if (IsOutputFused(ctx) && !IsInt8<T_out>()) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(
permuted_md.reshape(phi::vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
}
}
template <typename T>
......@@ -836,8 +841,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_format(paddle::platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(squeezed_dims)));
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
......@@ -1119,9 +1123,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
template <typename T>
......@@ -1184,13 +1187,13 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext &ctx) const {
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_format(x.format());
dx->set_mem_desc(x.mem_desc());
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_format(y.format());
dy->set_mem_desc(y.mem_desc());
}
}
}
......
......@@ -221,7 +221,7 @@ class MulPrimitiveFactory {
to_void_cast<T>(x_tmp.data<T>()));
x_tmp.Resize(data->dims());
x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
x_tmp.set_mem_desc(dst_mdesc);
data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
} else {
data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
......@@ -235,11 +235,7 @@ class MulPrimitiveFactory {
const Tensor *in) {
x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::undef) {
auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format);
}
out->set_mem_desc(output_->get_desc());
}
template <typename T>
......@@ -272,7 +268,7 @@ class MulPrimitiveFactory {
auto buffer_size = dst_desc.get_size();
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
output->set_mem_desc(dst_desc);
return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
}
......@@ -392,9 +388,10 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
if (out_dims.size() != 2) {
out->Resize(out_dims);
}
out->set_layout(DataLayout::kMKLDNN);
out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
MKLDNNMemoryFormat::nchw));
auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md(
mul.get_primitive_desc(), dnnl_query_dst_md, 0));
out->set_mem_desc(in_md.reshape(phi::vectorize<int64_t>(out->dims())));
}
};
......@@ -442,10 +439,11 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
// plain output formats are enforced inside handler
out->set_format(platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw));
// This kernel is flattening dims so then we need to unflattened version
// that should be set in out reshape require plain layout, but
// MatmulV2MKLDNNHanlder enforces one so it should work
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
}
private:
......
......@@ -196,7 +196,8 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) {
// TODO(jczaja): Why not for int8??
if (!IsInt8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
......
......@@ -121,8 +121,10 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
OneDNNContext::tls().set_cur_paddle_data_layout(src_layout);
}
out->set_layout(DataLayout::ONEDNN);
out->set_format(out_format);
dnnl::memory::desc out_mem_desc(vectorize<int64_t>(out->dims()),
funcs::ToOneDNNDataType(x.dtype()),
out_format);
out->set_mem_desc(out_mem_desc);
} else if (src_layout == DataLayout::ONEDNN &&
dst_layout != DataLayout::ONEDNN) {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册