diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc deleted file mode 100644 index b8638ab17c7dbc56c26d77a992e8791800d6d363..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ /dev/null @@ -1,502 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace operators { - -using framework::DDim; -using framework::ExecutionContext; - -using phi::OneDNNContext; -using platform::MatMulV2MKLDNNHandler; - -using dnnl::inner_product_forward; -using dnnl::memory; -using dnnl::prop_kind; -using dnnl::stream; - -template -class MulPrimitiveFactory { - public: - explicit MulPrimitiveFactory(const dnnl::engine &engine) : engine_(engine) {} - - inner_product_forward CreateMulPrimitive(const Tensor *x_input, - const Tensor *y_input, - Tensor *output, - const ExecutionContext &ctx) { - /* check data format and reorder if need */ - int x_num_col_dims = ctx.Attr("x_num_col_dims"); - int y_num_col_dims = ctx.Attr("y_num_col_dims"); - - // TODO(intel-minghui) : Remove the restriction that only supports Input(Y) - // as weights - PADDLE_ENFORCE_EQ( - (std::is_same::value), - true, - platform::errors::InvalidArgument( - "Input(Y) must be fp32 data type since only fp32 data type is " - "supported in the current design of MKLDNN INT8.")); - - auto x_matrix = UpdateDataFormat(x_input, x_num_col_dims, ctx); - auto y_matrix = UpdateDataFormat(y_input, y_num_col_dims, ctx); - - auto output_dim = output->dims(); - if (output_dim.size() != 2) { - output->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - - if (mul_) { - UpdateDataPointers(ctx, output, &x_matrix); - Execute(); - return *(mul_); - } - - auto src_desc = CreateMemDescriptor(&x_matrix, OneDNNMemoryFormat::nc); - x_input_ = CreateMemory(src_desc, &x_matrix); - - if (is_int8_) { - const auto trans_y = TransposeInputY(&y_matrix); - auto scale_y = ctx.Attr>("scale_y"); - y_input_ = QuantInputY(trans_y, scale_y); - } else { - y_input_ = TransposeInputY(&y_matrix); - } - - auto dst_desc = CreateMemDescriptor(output, OneDNNMemoryFormat::any); - - mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx); - Execute(); - return *(mul_); - } - - private: - memory ReorderWithScale(const memory::desc &src_desc, - const memory::desc &dst_desc, - void *src_data, - const std::vector &scale) { - auto mask = scale.size() > 1 ? 
1 : 0; - dnnl::primitive_attr attr; - attr.set_output_scales(mask, scale); - - auto src_mem = memory(src_desc, engine_, src_data); - auto dst_mem = memory(dst_desc, engine_); - - auto reorder_pd = dnnl::reorder::primitive_desc(src_mem, dst_mem, attr); - - auto reorder = dnnl::reorder(reorder_pd); - - auto &astream = OneDNNContext::tls().get_stream(); - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder.execute(astream, src_mem, dst_mem); - astream.wait(); - } - - return dst_mem; - } - - memory QuantInputY(memory input_y, const std::vector &scale_y) { - const auto &dims = input_y.get_desc().data.dims; - auto ndims = input_y.get_desc().data.ndims; - auto y_dims = std::vector(dims, dims + ndims); - - auto user_y_desc = CreateMemDescriptor(y_dims, OneDNNMemoryFormat::oi); - auto y_desc = CreateMemDescriptor(y_dims, OneDNNMemoryFormat::oi); - - return ReorderWithScale( - user_y_desc, y_desc, input_y.get_data_handle(), scale_y); - } - - dnnl::primitive_attr CreateMulAttr(const ExecutionContext &ctx, - bool force_fp32_output) { - dnnl::primitive_attr mul_attr; - - auto scale_y_data = ctx.Attr>("scale_y"); - auto scale_x_data = ctx.Attr("scale_x"); - auto scale_out_data = - force_fp32_output ? 1.0f : ctx.Attr("scale_out"); - - bool is_multi_channel = scale_y_data.size() > 1; - int count = is_multi_channel ? scale_y_data.size() : 1; - std::vector output_shift_scale(count); - for (int i = 0; i < count; i++) { - if (scale_y_data[i] == 0.0) - output_shift_scale[i] = scale_out_data; - else - output_shift_scale[i] = - scale_out_data / (scale_x_data * scale_y_data[i]); - } - int mul_mask = is_multi_channel ? 1 : 0; - mul_attr.set_output_scales(mul_mask, output_shift_scale); - - return mul_attr; - } - - inner_product_forward CreateMulPrimitive(const memory &x_memory, - const memory &y_memory, - const memory::desc &dst_desc, - Tensor *output, - const ExecutionContext &ctx) { - const auto x_desc = x_memory.get_desc(); - const auto y_desc = y_memory.get_desc(); - inner_product_forward::primitive_desc mul_prim_desc; - - const auto &mul_desc = inner_product_forward::desc( - prop_kind::forward, x_desc, y_desc, dst_desc); - - if (is_int8_) { - bool force_fp32_output = ctx.Attr("force_fp32_output"); - auto mul_attr = CreateMulAttr(ctx, force_fp32_output); - mul_prim_desc = - inner_product_forward::primitive_desc(mul_desc, mul_attr, engine_); - } else { - mul_prim_desc = inner_product_forward::primitive_desc(mul_desc, engine_); - } - - output_ = CreateDstMemory(mul_prim_desc, ctx, output); - - return inner_product_forward(mul_prim_desc); - } - - void Execute() { - auto &astream = OneDNNContext::tls().get_stream(); - (*mul_).execute(astream, - {{DNNL_ARG_SRC, *x_input_}, - {DNNL_ARG_WEIGHTS, *y_input_}, - {DNNL_ARG_DST, *output_}}); - astream.wait(); - } - - template - Tensor UpdateDataFormat(const Tensor *data, - int num_col_dims, - const ExecutionContext &ctx) { - Tensor x_tmp; - Tensor data_matrix; - // This code is enforcing plain (non-blocked) memory arrangement - // in order to flatten (reduce dimensionality) of Tensor later - auto src_mdesc = data->mem_desc(); - auto dst_mdesc = - data->dims().size() >= 4 - ? (data->dims().size() == 5 - ? 
CreateMemDescriptor(data, OneDNNMemoryFormat::ncdhw) - : CreateMemDescriptor(data, OneDNNMemoryFormat::nchw)) - : src_mdesc; - - if (src_mdesc != dst_mdesc) { - x_tmp.mutable_data(ctx.GetPlace(), data->memory_size()); - - Reorder(src_mdesc, - dst_mdesc, - phi::funcs::to_void_cast(data->data()), - phi::funcs::to_void_cast(x_tmp.data())); - - x_tmp.Resize(data->dims()); - x_tmp.set_mem_desc(dst_mdesc); - data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims); - } else { - data_matrix = framework::ReshapeToMatrix(*data, num_col_dims); - } - - return data_matrix; - } - - void UpdateDataPointers(const ExecutionContext &ctx, - Tensor *out, - const Tensor *in) { - x_input_->set_data_handle(phi::funcs::to_void_cast(in->data())); - output_->set_data_handle(out->mutable_data(ctx.GetPlace())); - out->set_mem_desc(output_->get_desc()); - } - - template - memory::desc CreateMemDescriptor( - const Tensor *tensor, - OneDNNMemoryFormat format, - memory::data_type type = phi::funcs::OneDNNGetDataType()) { - auto dims = phi::vectorize(tensor->dims()); - return phi::funcs::OneDNNMemDesc(dims, type, format); - } - - template - memory::desc CreateMemDescriptor( - const std::vector &dims, - OneDNNMemoryFormat format, - memory::data_type type = phi::funcs::OneDNNGetDataType()) { - return phi::funcs::OneDNNMemDesc(dims, type, format); - } - - template - memory CreateMemory(const memory::desc &desc, const Tensor *tensor) { - return memory( - desc, engine_, phi::funcs::to_void_cast(tensor->data())); - } - - memory CreateDstMemory( - const inner_product_forward::primitive_desc &mul_prim_desc, - const ExecutionContext &ctx, - Tensor *output) { - auto dst_desc = mul_prim_desc.dst_desc(); - auto buffer_size = dst_desc.get_size(); - - OT *output_data = output->mutable_data(ctx.GetPlace(), buffer_size); - output->set_mem_desc(dst_desc); - return memory(dst_desc, engine_, phi::funcs::to_void_cast(output_data)); - } - - memory Reorder(const memory::desc &src_desc, - const memory::desc &dst_desc, - void *src_data, - void *dst_data = NULL) { - auto src_mem = memory(src_desc, engine_, src_data); - auto dst_mem = dst_data ? 
memory(dst_desc, engine_, dst_data) - : memory(dst_desc, engine_); - - auto reorder = dnnl::reorder(src_mem, dst_mem); - - auto &astream = OneDNNContext::tls().get_stream(); - { - platform::RecordEvent record_reorder( - "int_reorder", - platform::TracerEventType::UserDefined, - 2, - platform::EventRole::kUniqueOp); - reorder.execute(astream, src_mem, dst_mem); - astream.wait(); - } - - return dst_mem; - } - - memory TransposeInputY(const Tensor *input_y) { - auto dims = phi::vectorize(input_y->dims()); - std::swap(dims[0], dims[1]); // Correct output dimensions - auto src_desc = CreateMemDescriptor(dims, OneDNNMemoryFormat::io); - auto dst_desc = CreateMemDescriptor(dims, OneDNNMemoryFormat::oi); - return Reorder( - src_desc, dst_desc, phi::funcs::to_void_cast(input_y->data())); - } - - const dnnl::engine &engine_; - paddle::optional x_input_; - paddle::optional y_input_; - paddle::optional output_; - paddle::optional mul_; - static constexpr bool is_int8_ = - std::is_same::value || std::is_same::value; -}; - -/* OT: output data type */ -template -std::shared_ptr> GetPrimitiveFactory( - const OneDNNContext &dev_ctx, - const ExecutionContext &ctx, - const Tensor *input_x, - const Tensor *input_y, - const dnnl::engine &mkldnn_engine) { - std::string key = - phi::funcs::CreateKey(dev_ctx, - framework::TransToProtoVarType(input_x->dtype()), - phi::vectorize(input_x->dims()), - framework::TransToProtoVarType(input_y->dtype()), - phi::vectorize(input_y->dims()), - ctx.OutputName("Out")); - key = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - - auto prim_creator = std::static_pointer_cast>( - dev_ctx.GetBlob(key)); - - if (prim_creator == nullptr) { - prim_creator = - std::make_shared>(mkldnn_engine); - dev_ctx.SetBlob(key, prim_creator); - } - - return prim_creator; -} - -/* XT: input x data type, YT: input y data type */ -template -inner_product_forward GetMulPrimitive(const OneDNNContext &dev_ctx, - const ExecutionContext &ctx, - const Tensor *input_x, - const Tensor *input_y, - Tensor *output, - const dnnl::engine &mkldnn_engine) { - constexpr bool is_int8 = - std::is_same::value || std::is_same::value; - bool force_fp32_output = ctx.Attr("force_fp32_output"); - - if (is_int8 && !force_fp32_output) { - return GetPrimitiveFactory( - dev_ctx, ctx, input_x, input_y, mkldnn_engine) - ->CreateMulPrimitive(input_x, input_y, output, ctx); - - } else { - return GetPrimitiveFactory( - dev_ctx, ctx, input_x, input_y, mkldnn_engine) - ->CreateMulPrimitive(input_x, input_y, output, ctx); - } -} - -/* XT: input x data type */ -template -class MulMKLDNNINT8Kernel : public framework::OpKernel { - public: - void Compute(const ExecutionContext &ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), - true, - paddle::platform::errors::PreconditionNotMet( - "Operator DNNL Mul must use CPUPlace")); - OneDNNContext::tls().log_lib_version(); - auto &dev_ctx = ctx.template device_context(); - auto &mkldnn_engine = dev_ctx.GetEngine(); - - const Tensor *x = ctx.Input("X"); - const Tensor *y = ctx.Input("Y"); - Tensor *out = ctx.Output("Out"); - auto out_dims = out->dims(); - - auto mul = - GetMulPrimitive(dev_ctx, ctx, x, y, out, mkldnn_engine); - - if (out_dims.size() != 2) { - out->Resize(out_dims); - } - - auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md( - mul.get_primitive_desc(), dnnl_query_dst_md, 0)); - out->set_mem_desc(in_md.reshape(phi::vectorize(out->dims()))); - } -}; - -template -class MulMKLDNNKernel : public framework::OpKernel { - public: - void 
Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } - - protected: - void ExecuteMatMul(const ExecutionContext &ctx, - const OneDNNContext &dev_ctx, - const dnnl::engine &onednn_engine, - const platform::Place &cpu_place, - const Tensor *x, - const std::vector &x_dims, - bool trans_x, - const Tensor *y, - const std::vector &y_dims, - bool trans_y, - Tensor *out) const { - static const std::vector vec_placeholder; - MatMulV2MKLDNNHandler handler(ctx, - onednn_engine, - ctx.GetPlace(), - x_dims, - trans_x, - y_dims, - trans_y, - false, - vec_placeholder, - vec_placeholder); - - const auto src_memory_p = handler.AcquireSrcMemory(x); - const auto weights_memory_p = handler.AcquireWeightsMemory(y); - const auto dst_memory_p = handler.AcquireDstMemory(out); - - auto matmul_p = handler.AcquireForwardPrimitive(); - - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - auto &astream = OneDNNContext::tls().get_stream(); - matmul_p->execute(astream, matmul_args); - astream.wait(); - - // This kernel is flattening dims so then we need to unflattened version - // that should be set in out reshape require plain layout, but - // MatmulV2MKLDNNHanlder enforces one so it should work - out->set_mem_desc( - dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()))); - } - - private: - void RunKernel(const ExecutionContext &ctx) const { - const auto &dev_ctx = ctx.template device_context(); - const auto &onednn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - const auto *y = ctx.Input("Y"); - auto *out = ctx.Output("Out"); - - int x_num_col_dims = ctx.Attr("x_num_col_dims"); - int y_num_col_dims = ctx.Attr("y_num_col_dims"); - - const Tensor x_matrix = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, x_num_col_dims) - : *x; - const Tensor y_matrix = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, y_num_col_dims) - : *y; - - // adding mb dim because MatMulV2 handler needs it - std::vector y_dims(3, 1); - std::vector x_dims(3, 1); - - y_dims[1] = y_matrix.dims()[0]; - y_dims[2] = y_matrix.dims()[1]; - - x_dims[1] = x_matrix.dims()[0]; - x_dims[2] = x_matrix.dims()[1]; - - ExecuteMatMul(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - &x_matrix, - x_dims, - false, - &y_matrix, - y_dims, - false, - out); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(mul, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::MulMKLDNNINT8Kernel, - ops::MulMKLDNNINT8Kernel, - ops::MulMKLDNNKernel, - ops::MulMKLDNNKernel); diff --git a/paddle/phi/kernels/onednn/matmul_kernel.cc b/paddle/phi/kernels/onednn/matmul_kernel.cc index 30a1735c5184aadb381e294d748a0aa5711b5541..c820e738f09348b7d207dbd81e33fcb40b615d98 100644 --- a/paddle/phi/kernels/onednn/matmul_kernel.cc +++ b/paddle/phi/kernels/onednn/matmul_kernel.cc @@ -12,11 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include + #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/core/kernel_registry.h" +using dnnl::engine; +using dnnl::inner_product_forward; +using dnnl::memory; +using dnnl::prop_kind; +using dnnl::stream; +using paddle::framework::ReshapeToMatrix; + namespace phi { DDim GetDimsForInput(const OneDNNContext &dev_ctx, @@ -152,6 +161,418 @@ void MatmulKernel(const Context &dev_ctx, } } +template +class MulPrimitiveFactory { + public: + explicit MulPrimitiveFactory(const engine &engine) : engine_(engine) {} + + inner_product_forward CreateMulPrimitive(const DenseTensor *x_input, + const DenseTensor *y_input, + DenseTensor *output, + int x_num_col_dims, + int y_num_col_dims, + const OneDNNContext &dev_ctx) { + // TODO(intel-minghui) : Remove the restriction that only supports Input(Y) + // as weights + PADDLE_ENFORCE_EQ( + (std::is_same::value), + true, + errors::InvalidArgument( + "Input(Y) must be fp32 data type since only fp32 data type is " + "supported in the current design of OneDNN INT8.")); + + /* check data format and reorder if need */ + auto x_matrix = UpdateDataFormat(x_input, x_num_col_dims, dev_ctx); + auto y_matrix = UpdateDataFormat(y_input, y_num_col_dims, dev_ctx); + + auto output_dim = output->dims(); + if (output_dim.size() != 2) { + output->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + + if (mul_) { + UpdateDataPointers(dev_ctx, output, &x_matrix); + Execute(); + return *(mul_); + } + + auto src_desc = + CreateMemDescriptor(&x_matrix, funcs::OneDNNMemoryFormat::nc); + x_input_ = CreateMemory(src_desc, &x_matrix); + + if (is_int8_) { + const auto trans_y = TransposeInputY(&y_matrix); + auto scale_y = dev_ctx.HasDnnAttr("scale_y") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("scale_y")) + : std::vector(); + y_input_ = QuantInputY(trans_y, scale_y); + } else { + y_input_ = TransposeInputY(&y_matrix); + } + + auto dst_desc = + CreateMemDescriptor(output, funcs::OneDNNMemoryFormat::any); + + mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, dev_ctx); + Execute(); + return *(mul_); + } + + private: + memory ReorderWithScale(const memory::desc &src_desc, + const memory::desc &dst_desc, + void *src_data, + const std::vector &scale) { + auto mask = scale.size() > 1 ? 
1 : 0; + dnnl::primitive_attr attr; + attr.set_output_scales(mask, scale); + + auto src_mem = memory(src_desc, engine_, src_data); + auto dst_mem = memory(dst_desc, engine_); + + auto reorder_pd = dnnl::reorder::primitive_desc(src_mem, dst_mem, attr); + + auto reorder = dnnl::reorder(reorder_pd); + + auto &astream = OneDNNContext::tls().get_stream(); + { + paddle::platform::RecordEvent record_reorder( + "int_reorder", + paddle::platform::TracerEventType::UserDefined, + 2, + paddle::platform::EventRole::kUniqueOp); + reorder.execute(astream, src_mem, dst_mem); + astream.wait(); + } + + return dst_mem; + } + + memory QuantInputY(memory input_y, const std::vector &scale_y) { + const auto &dims = input_y.get_desc().data.dims; + auto ndims = input_y.get_desc().data.ndims; + auto y_dims = std::vector(dims, dims + ndims); + + auto user_y_desc = + CreateMemDescriptor(y_dims, funcs::OneDNNMemoryFormat::oi); + auto y_desc = + CreateMemDescriptor(y_dims, funcs::OneDNNMemoryFormat::oi); + + return ReorderWithScale( + user_y_desc, y_desc, input_y.get_data_handle(), scale_y); + } + + dnnl::primitive_attr CreateMulAttr(const OneDNNContext &dev_ctx, + bool force_fp32_output) { + dnnl::primitive_attr mul_attr; + + auto scale_y_data = dev_ctx.HasDnnAttr("scale_y") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("scale_y")) + : std::vector{1.0}; + auto scale_x_data = + dev_ctx.HasDnnAttr("scale_x") + ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("scale_x")) + : 1.0f; + auto scale_out = + dev_ctx.HasDnnAttr("scale_out") + ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("scale_out")) + : 1.0f; + auto scale_out_data = force_fp32_output ? 1.0f : scale_out; + + bool is_multi_channel = scale_y_data.size() > 1; + int count = is_multi_channel ? scale_y_data.size() : 1; + std::vector output_shift_scale(count); + for (int i = 0; i < count; i++) { + if (scale_y_data[i] == 0.0) + output_shift_scale[i] = scale_out_data; + else + output_shift_scale[i] = + scale_out_data / (scale_x_data * scale_y_data[i]); + } + int mul_mask = is_multi_channel ? 1 : 0; + mul_attr.set_output_scales(mul_mask, output_shift_scale); + + return mul_attr; + } + + inner_product_forward CreateMulPrimitive(const memory &x_memory, + const memory &y_memory, + const memory::desc &dst_desc, + DenseTensor *output, + const OneDNNContext &dev_ctx) { + const auto x_desc = x_memory.get_desc(); + const auto y_desc = y_memory.get_desc(); + inner_product_forward::primitive_desc mul_prim_desc; + + const auto &mul_desc = inner_product_forward::desc( + prop_kind::forward, x_desc, y_desc, dst_desc); + + if (is_int8_) { + bool force_fp32_output = + dev_ctx.HasDnnAttr("force_fp32_output") + ? 
PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) + : false; + auto mul_attr = CreateMulAttr(dev_ctx, force_fp32_output); + mul_prim_desc = + inner_product_forward::primitive_desc(mul_desc, mul_attr, engine_); + } else { + mul_prim_desc = inner_product_forward::primitive_desc(mul_desc, engine_); + } + + output_ = CreateDstMemory(mul_prim_desc, dev_ctx, output); + + return inner_product_forward(mul_prim_desc); + } + + void Execute() { + auto &astream = OneDNNContext::tls().get_stream(); + (*mul_).execute(astream, + {{DNNL_ARG_SRC, *x_input_}, + {DNNL_ARG_WEIGHTS, *y_input_}, + {DNNL_ARG_DST, *output_}}); + astream.wait(); + } + + template + DenseTensor UpdateDataFormat(const DenseTensor *data, + int num_col_dims, + const OneDNNContext &dev_ctx) { + DenseTensor x_tmp; + DenseTensor data_matrix; + // This code is enforcing plain (non-blocked) memory arrangement + // in order to flatten (reduce dimensionality) of DenseTensor later + auto src_mdesc = data->mem_desc(); + auto dst_mdesc = data->dims().size() >= 4 + ? (data->dims().size() == 5 + ? CreateMemDescriptor( + data, funcs::OneDNNMemoryFormat::ncdhw) + : CreateMemDescriptor( + data, funcs::OneDNNMemoryFormat::nchw)) + : src_mdesc; + + if (src_mdesc != dst_mdesc) { + dev_ctx.template Alloc(&x_tmp, data->memory_size()); + + Reorder(src_mdesc, + dst_mdesc, + funcs::to_void_cast(data->data()), + funcs::to_void_cast(x_tmp.data())); + + x_tmp.Resize(data->dims()); + x_tmp.set_mem_desc(dst_mdesc); + data_matrix = ReshapeToMatrix(x_tmp, num_col_dims); + } else { + data_matrix = ReshapeToMatrix(*data, num_col_dims); + } + + return data_matrix; + } + + void UpdateDataPointers(const OneDNNContext &dev_ctx, + DenseTensor *out, + const DenseTensor *in) { + x_input_->set_data_handle(funcs::to_void_cast(in->data())); + output_->set_data_handle(dev_ctx.template Alloc(out)); + out->set_mem_desc(output_->get_desc()); + } + + template + memory::desc CreateMemDescriptor( + const DenseTensor *tensor, + funcs::OneDNNMemoryFormat format, + memory::data_type type = funcs::OneDNNGetDataType()) { + auto dims = vectorize(tensor->dims()); + return funcs::OneDNNMemDesc(dims, type, format); + } + + template + memory::desc CreateMemDescriptor( + const std::vector &dims, + funcs::OneDNNMemoryFormat format, + memory::data_type type = funcs::OneDNNGetDataType()) { + return funcs::OneDNNMemDesc(dims, type, format); + } + + template + memory CreateMemory(const memory::desc &desc, const DenseTensor *tensor) { + return memory(desc, engine_, funcs::to_void_cast(tensor->data())); + } + + memory CreateDstMemory( + const inner_product_forward::primitive_desc &mul_prim_desc, + const OneDNNContext &dev_ctx, + DenseTensor *output) { + auto dst_desc = mul_prim_desc.dst_desc(); + auto buffer_size = dst_desc.get_size(); + + OT *output_data = dev_ctx.template Alloc(output, buffer_size); + output->set_mem_desc(dst_desc); + return memory(dst_desc, engine_, funcs::to_void_cast(output_data)); + } + + memory Reorder(const memory::desc &src_desc, + const memory::desc &dst_desc, + void *src_data, + void *dst_data = NULL) { + auto src_mem = memory(src_desc, engine_, src_data); + auto dst_mem = dst_data ? 
memory(dst_desc, engine_, dst_data) + : memory(dst_desc, engine_); + + auto reorder = dnnl::reorder(src_mem, dst_mem); + + auto &astream = OneDNNContext::tls().get_stream(); + { + paddle::platform::RecordEvent record_reorder( + "int_reorder", + paddle::platform::TracerEventType::UserDefined, + 2, + paddle::platform::EventRole::kUniqueOp); + reorder.execute(astream, src_mem, dst_mem); + astream.wait(); + } + + return dst_mem; + } + + memory TransposeInputY(const DenseTensor *input_y) { + auto dims = vectorize(input_y->dims()); + std::swap(dims[0], dims[1]); // Correct output dimensions + auto src_desc = + CreateMemDescriptor(dims, funcs::OneDNNMemoryFormat::io); + auto dst_desc = + CreateMemDescriptor(dims, funcs::OneDNNMemoryFormat::oi); + return Reorder( + src_desc, dst_desc, funcs::to_void_cast(input_y->data())); + } + + const engine &engine_; + paddle::optional x_input_; + paddle::optional y_input_; + paddle::optional output_; + paddle::optional mul_; + static constexpr bool is_int8_ = funcs::is_int8(); +}; + +/* OT: output data type */ +template +std::shared_ptr> GetPrimitiveFactory( + const OneDNNContext &dev_ctx, + const DenseTensor *input_x, + const DenseTensor *input_y, + const engine &onednn_engine) { + std::string key = funcs::CreateKey(dev_ctx, + TransToProtoVarType(input_x->dtype()), + vectorize(input_x->dims()), + TransToProtoVarType(input_y->dtype()), + vectorize(input_y->dims()), + dev_ctx.GetOutputsName("Out")[0]); + key = funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + + auto prim_creator = std::static_pointer_cast>( + dev_ctx.GetBlob(key)); + + if (prim_creator == nullptr) { + prim_creator = + std::make_shared>(onednn_engine); + dev_ctx.SetBlob(key, prim_creator); + } + + return prim_creator; +} + +/* XT: input x data type, YT: input y data type */ +template +inner_product_forward GetMulPrimitive(const OneDNNContext &dev_ctx, + const DenseTensor *input_x, + const DenseTensor *input_y, + DenseTensor *output, + int x_num_col_dims, + int y_num_col_dims, + const engine &onednn_engine) { + constexpr bool is_int8 = funcs::is_int8(); + bool force_fp32_output = + dev_ctx.HasDnnAttr("force_fp32_output") + ? 
PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
+          : false;
+
+  if (is_int8 && !force_fp32_output) {
+    return GetPrimitiveFactory<XT, YT, int8_t>(
+               dev_ctx, input_x, input_y, onednn_engine)
+        ->CreateMulPrimitive(
+            input_x, input_y, output, x_num_col_dims, y_num_col_dims, dev_ctx);
+
+  } else {
+    return GetPrimitiveFactory<XT, YT, float>(
+               dev_ctx, input_x, input_y, onednn_engine)
+        ->CreateMulPrimitive(
+            input_x, input_y, output, x_num_col_dims, y_num_col_dims, dev_ctx);
+  }
+}
+
+/* XT: input x data type */
+template <typename XT, typename Context>
+void MatmulWithFlattenKernelINT8(const Context &dev_ctx,
+                                 const DenseTensor &x,
+                                 const DenseTensor &y,
+                                 int x_num_col_dims,
+                                 int y_num_col_dims,
+                                 DenseTensor *out) {
+  PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType() == AllocationType::CPU,
+                    true,
+                    errors::PreconditionNotMet(
+                        "oneDNN MatmulWithFlatten kernel must use CPUPlace"));
+
+  OneDNNContext::tls().log_lib_version();
+  auto &onednn_engine = dev_ctx.GetEngine();
+
+  auto out_dims = out->dims();
+
+  auto mul = GetMulPrimitive<XT, float>(
+      dev_ctx, &x, &y, out, x_num_col_dims, y_num_col_dims, onednn_engine);
+
+  if (out_dims.size() != 2) {
+    out->Resize(out_dims);
+  }
+
+  auto in_md = memory::desc(*dnnl_primitive_desc_query_md(
+      mul.get_primitive_desc(), dnnl_query_dst_md, 0));
+  out->set_mem_desc(in_md.reshape(vectorize(out->dims())));
+}
+
+template <typename T, typename Context>
+void MatmulWithFlattenKernel(const Context &dev_ctx,
+                             const DenseTensor &x,
+                             const DenseTensor &y,
+                             int x_num_col_dims,
+                             int y_num_col_dims,
+                             DenseTensor *out) {
+  constexpr bool is_int8 = funcs::is_int8<T>();
+  if (is_int8) {
+    MatmulWithFlattenKernelINT8<T, Context>(
+        dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out);
+    return;
+  }
+
+  const DenseTensor x_matrix =
+      x.dims().size() > 2 ? ReshapeToMatrix(x, x_num_col_dims) : x;
+  const DenseTensor y_matrix =
+      y.dims().size() > 2 ? ReshapeToMatrix(y, y_num_col_dims) : y;
+
+  // adding mb dim because MatMulV2 handler needs it
+  std::vector<int64_t> x_dims(3, 1);
+  std::vector<int64_t> y_dims(3, 1);
+
+  x_dims[1] = x_matrix.dims()[0];
+  x_dims[2] = x_matrix.dims()[1];
+  y_dims[1] = y_matrix.dims()[0];
+  y_dims[2] = y_matrix.dims()[1];
+
+  funcs::ExecuteMul<T>(
+      dev_ctx, x_matrix, y_matrix, x_dims, y_dims, false, false, out);
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(matmul,
@@ -162,3 +583,12 @@ PD_REGISTER_KERNEL(matmul,
                    phi::dtype::bfloat16,
                    int8_t,
                    uint8_t) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten,
+                   OneDNN,
+                   ONEDNN,
+                   phi::MatmulWithFlattenKernel,
+                   float,
+                   phi::dtype::bfloat16,
+                   uint8_t,
+                   int8_t) {}
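
For reference, a minimal sketch of the output-scale arithmetic that CreateMulAttr above encodes, assuming a hypothetical free-function form (ComputeOutputShiftScale is not part of the patch): each oneDNN output scale maps the INT32 accumulator back to the requested output range via scale_out / (scale_x * scale_y[i]), falling back to scale_out when a channel's weight scale is zero.

#include <vector>

// Hypothetical helper mirroring CreateMulAttr's loop; names are illustrative.
std::vector<float> ComputeOutputShiftScale(float scale_x,
                                           const std::vector<float> &scale_y,
                                           float scale_out,
                                           bool force_fp32_output) {
  const float out_scale = force_fp32_output ? 1.0f : scale_out;
  const bool is_multi_channel = scale_y.size() > 1;
  const int count = is_multi_channel ? static_cast<int>(scale_y.size()) : 1;
  std::vector<float> output_shift_scale(count);
  for (int i = 0; i < count; ++i) {
    // A zero weight scale would divide by zero, so keep the output scale as-is;
    // otherwise undo the input and weight quantization before re-quantizing.
    output_shift_scale[i] = (scale_y[i] == 0.0f)
                                ? out_scale
                                : out_scale / (scale_x * scale_y[i]);
  }
  // set_output_scales would then be called with mask = is_multi_channel ? 1 : 0.
  return output_shift_scale;
}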
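
In the non-INT8 path, MatmulWithFlattenKernel first collapses each input to a 2-D matrix with ReshapeToMatrix and then prepends a batch dimension of 1, since the MatMulV2-based ExecuteMul helper works on 3-D shapes. A small sketch of that shape handling, using a hypothetical helper and example shapes chosen only for illustration:

#include <cstdint>
#include <vector>

// Hypothetical helper: turn the 2-D matrix shape produced by ReshapeToMatrix
// into the 3-D dims vector expected by the MatMulV2 handler.
std::vector<int64_t> ToHandlerDims(int64_t rows, int64_t cols) {
  std::vector<int64_t> dims(3, 1);  // leading "mb" dim stays 1
  dims[1] = rows;
  dims[2] = cols;
  return dims;
}

// Example: an x of shape [2, 3, 4] with x_num_col_dims == 1 flattens to
// [2, 12], so the handler sees {1, 2, 12}; a y of shape [12, 5] becomes
// {1, 12, 5}, and the 3-D matmul result is reshaped back to out->dims().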