diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 35f93eba690e8ed1f898db663f082ecfefc690b8..723c3c8352d545e3f9c7f5014237fd487d4c32b1 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" +#include using dnnl::memory; using dnnl::primitive; @@ -20,6 +21,7 @@ using paddle::framework::DataLayout; using paddle::framework::ExecutionContext; using paddle::framework::vectorize; using paddle::platform::GetMKLDNNFormat; +using paddle::platform::MKLDNNFormatForSize; using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNGetDataType; using paddle::platform::to_void_cast; @@ -82,15 +84,39 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx, } template +constexpr bool IsInt8() { + return std::is_same::value || std::is_same::value; +} + +template +constexpr bool IsBfloat16() { + return std::is_same::value; +} + +// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the +// original x_dim is returned. +static paddle::framework::DDim RowMatrixDimsFromVector( + const paddle::framework::DDim& x_dim) { + return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]}); +} + +// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the +// original y_dim is returned. +static paddle::framework::DDim ColumnMatrixDimsFromVector( + const paddle::framework::DDim& y_dim) { + return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1}); +} + +template class MatMulMKLDNNHandler - : public paddle::platform::MKLDNNHandlerNoCachingT { + : public paddle::platform::MKLDNNHandlerNoCachingT { public: MatMulMKLDNNHandler(const mkldnn::engine engine, paddle::platform::Place cpu_place, Tensor* x, bool trans_x, Tensor* y, bool trans_y, Tensor* out, float scale) - : paddle::platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { + : paddle::platform::MKLDNNHandlerNoCachingT(engine, + cpu_place) { auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x); auto mat_dim_y = @@ -115,117 +141,98 @@ class MatMulMKLDNNHandler !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K}; memory::dims out_strides = memory::dims{M * N, N, 1}; - auto x_md = memory::desc(x_dims, MKLDNNGetDataType(), x_strides); - auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); - auto out_md = memory::desc(out_dims, MKLDNNGetDataType(), out_strides); + auto x_md = memory::desc(x_dims, MKLDNNGetDataType(), x_strides); + auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); + auto out_md = memory::desc(out_dims, MKLDNNGetDataType(), out_strides); dnnl::primitive_attr attrs; if (scale != 1.0f) attrs.set_output_scales(0, {scale}); this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); } + // Constructor for FWD MatMul + MatMulMKLDNNHandler(const mkldnn::engine engine, const ExecutionContext& ctx, + float scale) + : paddle::platform::MKLDNNHandlerNoCachingT( + engine, ctx.GetPlace()), + matmul_dims_(GetMatmulDims(ctx)) { + dnnl::primitive_attr attr; + float scale_out = ComputeOutputScale(ctx); + if (scale_out != 1.0f) { + constexpr unsigned tensor_wide_scale = 0; + attr.set_output_scales(tensor_wide_scale, {scale_out}); + } + + auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType(), + matmul_dims_.x_strides); + auto y_md = memory::desc(matmul_dims_.y_dims, MKLDNNGetDataType(), + matmul_dims_.y_strides); + auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType(), + matmul_dims_.out_strides); + this->AcquireForwardPrimitiveDescriptor(attr, x_md, y_md, out_md); + } std::shared_ptr AcquireWeightsMemory(const Tensor* input) { - const T* input_data = input->data(); + const YT* input_data = input->data(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), - to_void_cast(input_data)); + to_void_cast(input_data)); } -}; -template -constexpr bool IsInt8() { - return std::is_same::value || std::is_same::value; -} - -template -constexpr bool IsBfloat16() { - return std::is_same::value; -} - -// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the -// original x_dim is returned. -static paddle::framework::DDim RowMatrixDimsFromVector( - const paddle::framework::DDim& x_dim) { - return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]}); -} - -// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the -// original y_dim is returned. -static paddle::framework::DDim ColumnMatrixDimsFromVector( - const paddle::framework::DDim& y_dim) { - return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1}); -} - -/** - * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. - * - * The shape would be [BatchSize, H, W] or [H, W]. - * If transposed, `H,W` will be swapped. - */ -static void ReshapeTensorToMatrixSequence( - Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) { - int64_t h, w; - h = descriptor.height_; - w = descriptor.width_; - if (descriptor.trans_) { - std::swap(w, h); - } - if (descriptor.batch_size_) { - x->Resize({descriptor.batch_size_, h, w}); - } else { - x->Resize({h, w}); - } -} + public: + void Execute(const paddle::framework::Tensor* x, + const paddle::framework::Tensor* y, + paddle::framework::Tensor* out) { + const auto src_memory_p = this->AcquireSrcMemory(x); + const auto weights_memory_p = this->AcquireWeightsMemory(y); + const auto dst_memory_p = this->AcquireDstMemory(out); + + auto matmul_p = this->AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + + // Simulate batch matmul by processing in loop + void* x_ptr = src_memory_p->get_data_handle(); + void* y_ptr = weights_memory_p->get_data_handle(); + void* out_ptr = dst_memory_p->get_data_handle(); + auto offsets = this->GetOffsets(); + for (uint16_t i = 0; i < this->GetBatchSize(); ++i) { + src_memory_p->set_data_handle(x_ptr); + weights_memory_p->set_data_handle(y_ptr); + dst_memory_p->set_data_handle(out_ptr); + matmul_p->execute(astream, { + {MKLDNN_ARG_SRC, *src_memory_p}, + {MKLDNN_ARG_WEIGHTS, *weights_memory_p}, + {MKLDNN_ARG_DST, *dst_memory_p}, + }); + x_ptr = static_cast(x_ptr) + std::get<0>(offsets); + y_ptr = static_cast(y_ptr) + std::get<1>(offsets); + out_ptr = static_cast(out_ptr) + std::get<2>(offsets); + } + astream.wait(); -/** - * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor - * Out = matmul(x, y) - * - * This method will first calculate X,Y matrix sequence, and then calculate - * the out shape. - * - * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] - * The out = [BatchSize, H1, W2] - * - * If there is no batch size in `X` and `Y`, the out will be [H1, W2] - * If any of `X` and `Y` has batch size BatchSize, the out will have the - * BatchSize. - */ -static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out, - bool trans_x, bool trans_y) { - auto x_dim = RowMatrixDimsFromVector(x->dims()); - auto y_dim = ColumnMatrixDimsFromVector(y->dims()); - auto mat_dim_x = - paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = - paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); - if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { - out->Resize({mat_dim_x.height_, mat_dim_y.width_}); - } else { - out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), - mat_dim_x.height_, mat_dim_y.width_}); + auto format = + MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); + out->set_format(format); + out->set_layout(DataLayout::kMKLDNN); } - ReshapeTensorToMatrixSequence(x, mat_dim_x); - ReshapeTensorToMatrixSequence(y, mat_dim_y); -} - -template -class MatMulFactory { - public: - void CreateAndExecute(const ExecutionContext& ctx) { - SetDNNLEngine(ctx); - if (IsInitialized()) { - UpdateDataPointers(ctx); - Execute(); - SetOutputFormat(ctx); - return; - } - CreateMemories(ctx); - CreatePrimitive(ctx); - Execute(); - SetOutputFormat(ctx); - SetInitialized(); + std::shared_ptr AcquireDstMemory( + paddle::framework::Tensor* output) { + // We cannot use base AcquireDstMemory as it makes an allocation request + // base on DST memory primitive size. This is fine in general, but in MatMul + // we have primitive that covers only one batch of Data and then shift + // pointer for every new batch. Hence Tensor size is bigger that dst memory + // primitive size. So would we request less memory that is there and it + // triggers an + // assertion. So as there is no 'any' format here we can leave default size + // of Tensor as computed in ComputeInferShape + OT* ptr = output->mutable_data(this->place_); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); } private: @@ -234,47 +241,6 @@ class MatMulFactory { out_strides; }; - void SetDNNLEngine(const ExecutionContext& ctx) { - auto& dev_ctx = ctx.template device_context(); - engine_ = dev_ctx.GetEngine(); - } - - template - dnnl::memory CreateMemory(const memory::dims& dims, - const memory::dims& strides, const T* data) { - auto md = memory::desc(dims, MKLDNNGetDataType(), strides); - return dnnl::memory(md, engine_, to_void_cast(data)); - } - - std::vector Transpose(const std::vector& x, - const std::vector& axis) { - size_t in_rank = x.size(); - size_t axis_size = axis.size(); - - auto axis_set = std::set(axis.begin(), axis.end()); - PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, - paddle::platform::errors::InvalidArgument( - "In an axis array, elements must be unique.")); - - PADDLE_ENFORCE_EQ(in_rank, axis_size, - paddle::platform::errors::InvalidArgument( - "The input dimension's size " - "should be equal to the axis's size. " - "But received dimension is %d, " - "axis's size is %d", - in_rank, axis_size)); - - PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, - paddle::platform::errors::InvalidArgument( - "Axis values must be ranging from 0 to (dims - 1).")); - - std::vector new_x(x.size()); - for (size_t i = 0; i < x.size(); i++) { - new_x[i] = x[axis[i]]; - } - return new_x; - } - std::pair GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); @@ -310,6 +276,15 @@ class MatMulFactory { return std::make_pair(mat_dim, strides); } + float ComputeOutputScale(const ExecutionContext& ctx) { + float scale_x = ctx.Attr("Scale_x"); + float scale_y = ctx.Attr("Scale_y"); + bool force_fp32_out = ctx.Attr("force_fp32_output"); + float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); + float alpha = ctx.Attr("alpha"); + return alpha * scale_out / (scale_x * scale_y); + } + bool IsInputFused(const ExecutionContext& ctx) const { return !(ctx.Attr>("fused_reshape_X").empty() && ctx.Attr>("fused_reshape_Y").empty()); @@ -322,14 +297,6 @@ class MatMulFactory { return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); } - void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx, - const memory::dim N, memory::dim b, - memory::dims* out_strides) const { - if (!IsInt8() && !IsBfloat16() && IsOutputFused(ctx)) { - *out_strides = {N, b * N, 1}; - } - } - MatMulDims GetMatmulDims(const ExecutionContext& ctx) { paddle::operators::math::MatDescriptor mat_dim_x; memory::dims strides_x; @@ -381,125 +348,112 @@ class MatMulFactory { return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides}; } - void CreateMemories(const ExecutionContext& ctx) { - auto matmul_dims = GetMatmulDims(ctx); + std::vector Transpose(const std::vector& x, + const std::vector& axis) { + size_t in_rank = x.size(); + size_t axis_size = axis.size(); - x_mem_ = CreateMemory(matmul_dims.x_dims, matmul_dims.x_strides, - ctx.Input("X")->data()); - y_mem_ = CreateMemory(matmul_dims.y_dims, matmul_dims.y_strides, - ctx.Input("Y")->data()); - out_mem_ = CreateMemory( - matmul_dims.out_dims, matmul_dims.out_strides, - ctx.Output("Out")->mutable_data(ctx.GetPlace())); - } + auto axis_set = std::set(axis.begin(), axis.end()); + PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, + paddle::platform::errors::InvalidArgument( + "In an axis array, elements must be unique.")); - float ComputeOutputScale(const ExecutionContext& ctx) { - float scale_x = ctx.Attr("Scale_x"); - float scale_y = ctx.Attr("Scale_y"); - bool force_fp32_out = ctx.Attr("force_fp32_output"); - float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); - float alpha = ctx.Attr("alpha"); - return alpha * scale_out / (scale_x * scale_y); - } + PADDLE_ENFORCE_EQ(in_rank, axis_size, + paddle::platform::errors::InvalidArgument( + "The input dimension's size " + "should be equal to the axis's size. " + "But received dimension is %d, " + "axis's size is %d", + in_rank, axis_size)); - void CreatePrimitive(const ExecutionContext& ctx) { - dnnl::primitive_attr attr; - float scale_out = ComputeOutputScale(ctx); - if (scale_out != 1.0f) { - constexpr unsigned tensor_wide_scale = 0; - attr.set_output_scales(tensor_wide_scale, {scale_out}); - } + PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, + paddle::platform::errors::InvalidArgument( + "Axis values must be ranging from 0 to (dims - 1).")); - auto matmul_d = dnnl::matmul::desc(x_mem_.get_desc(), y_mem_.get_desc(), - out_mem_.get_desc()); - auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine_); - matmul_prim_ = dnnl::matmul(matmul_pd); + std::vector new_x(x.size()); + for (size_t i = 0; i < x.size(); i++) { + new_x[i] = x[axis[i]]; + } + return new_x; } - void Execute() { - dnnl::stream stream(engine_); - - void* x_ptr = x_mem_.get_data_handle(); - void* y_ptr = y_mem_.get_data_handle(); - void* out_ptr = out_mem_.get_data_handle(); - for (uint16_t i = 0; i < batch_size_; i++) { - x_mem_.set_data_handle(x_ptr); - y_mem_.set_data_handle(y_ptr); - out_mem_.set_data_handle(out_ptr); - matmul_prim_.execute(stream, { - {MKLDNN_ARG_SRC, x_mem_}, - {MKLDNN_ARG_WEIGHTS, y_mem_}, - {MKLDNN_ARG_DST, out_mem_}, - }); - x_ptr = static_cast(x_ptr) + x_offset_; - y_ptr = static_cast(y_ptr) + y_offset_; - out_ptr = static_cast(out_ptr) + out_offset_; + void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx, + const memory::dim N, memory::dim b, + memory::dims* out_strides) const { + if (!IsInt8() && !IsBfloat16() && IsOutputFused(ctx)) { + *out_strides = {N, b * N, 1}; } - stream.wait(); } - void SetOutputFormat(const ExecutionContext& ctx) { - using paddle::platform::MKLDNNFormatForSize; - auto* out = ctx.Output("Out"); - auto format = - MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); - out->set_format(format); - out->set_layout(DataLayout::kMKLDNN); - } + uint16_t GetBatchSize(void) const { return batch_size_; } - void UpdateDataPointers(const ExecutionContext& ctx) { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - x_mem_.set_data_handle(to_void_cast(x->data())); - y_mem_.set_data_handle(to_void_cast(y->data())); - out_mem_.set_data_handle(out->mutable_data(ctx.GetPlace())); + std::tuple GetOffsets() const { + return std::make_tuple(x_offset_, y_offset_, out_offset_); } - // If initialized, x memory should've been already initialized - bool IsInitialized() { return initialized_; } - - void SetInitialized() { initialized_ = true; } - private: - struct memory_offsets { - size_t x_offset; - size_t y_offset; - size_t out_offset; - }; - - dnnl::engine engine_; - dnnl::memory x_mem_; - dnnl::memory y_mem_; - dnnl::memory out_mem_; - dnnl::matmul matmul_prim_; + MatMulDims matmul_dims_; uint32_t x_offset_; uint32_t y_offset_; uint32_t out_offset_; uint16_t batch_size_; - bool initialized_ = false; }; -template -static std::shared_ptr> GetPrimitiveFactory( - const ExecutionContext& ctx) { - const auto& out_name = ctx.OutputName("Out"); - const auto& dev_ctx = ctx.template device_context(); - const auto batch_size = ctx.Input("X")->dims()[0]; - std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name); - key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - - auto factory = - std::static_pointer_cast>(dev_ctx.GetBlob(key)); - if (factory == nullptr) { - factory = std::make_shared>(); - dev_ctx.SetBlob(key, factory); +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorToMatrixSequence( + Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +/** + * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor + * Out = matmul(x, y) + * + * This method will first calculate X,Y matrix sequence, and then calculate + * the out shape. + * + * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] + * The out = [BatchSize, H1, W2] + * + * If there is no batch size in `X` and `Y`, the out will be [H1, W2] + * If any of `X` and `Y` has batch size BatchSize, the out will have the + * BatchSize. + */ +static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out, + bool trans_x, bool trans_y) { + auto x_dim = RowMatrixDimsFromVector(x->dims()); + auto y_dim = ColumnMatrixDimsFromVector(y->dims()); + auto mat_dim_x = + paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = + paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, mat_dim_y.width_}); } - return factory; + ReshapeTensorToMatrixSequence(x, mat_dim_x); + ReshapeTensorToMatrixSequence(y, mat_dim_y); } -// Choose appropriate primitive factory implementation based on inferred +// Choose appropriate Handler instances based on inferred // output type (uint8, int8 or float). template static void ExecuteMatMul(const ExecutionContext& ctx) { @@ -507,31 +461,41 @@ static void ExecuteMatMul(const ExecutionContext& ctx) { constexpr bool is_bfloat16 = IsBfloat16(); const bool force_fp32_output = ctx.Attr("force_fp32_output"); constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + const auto& dev_ctx = + ctx.template device_context(); + if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { - GetPrimitiveFactory(ctx)->CreateAndExecute(ctx); + MatMulMKLDNNHandler(dev_ctx.GetEngine(), ctx, alpha) + .Execute(x, y, out); } else if (is_bfloat16) { - GetPrimitiveFactory(ctx) - ->CreateAndExecute(ctx); + MatMulMKLDNNHandler(dev_ctx.GetEngine(), + ctx, alpha) + .Execute(x, y, out); } else if (fuse_relu) { - GetPrimitiveFactory(ctx)->CreateAndExecute(ctx); + MatMulMKLDNNHandler(dev_ctx.GetEngine(), ctx, alpha) + .Execute(x, y, out); } else { - GetPrimitiveFactory(ctx)->CreateAndExecute(ctx); + MatMulMKLDNNHandler(dev_ctx.GetEngine(), ctx, alpha) + .Execute(x, y, out); } } template -class DNNLMatMulKernel : public paddle::framework::OpKernel { +class MatMulMKLDNNKernel : public paddle::framework::OpKernel { public: void Compute(const ExecutionContext& ctx) const override { if (ctx.HasAttr("head_number")) { PADDLE_ENFORCE_EQ( ctx.Attr("head_number"), 1, paddle::platform::errors::Unimplemented( - "DNNL matmul doesn't support multiple heads. Expected " + "oneDNN matmul doesn't support multiple heads. Expected " "head_number=1. But received `head_number` is %d", ctx.Attr("head_number"))); } - MKLDNNDeviceContext::tls().log_lib_version(); ExecuteMatMul(ctx); } }; @@ -547,7 +511,7 @@ void MatMulGradMKLDNNKernel::Compute(const ExecutionContext& ctx) const { PADDLE_ENFORCE_EQ( ctx.Attr("head_number"), 1, platform::errors::Unimplemented( - "DNNL matmul doesn't support multiple heads. Expected " + "oneDNN matmul doesn't support multiple heads. Expected " "head_number=1. But received `head_number` is %d", ctx.Attr("head_number"))); } @@ -577,8 +541,9 @@ void MatMulGradMKLDNNKernel::ExecuteMatMulGrad( float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; - MatMulMKLDNNHandler handler(engine, ctx.GetPlace(), &x_combined, trans_x, - &y_combined, trans_y, out, alpha); + MatMulMKLDNNHandler handler(engine, ctx.GetPlace(), &x_combined, + trans_x, &y_combined, trans_y, out, + alpha); const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); @@ -679,9 +644,9 @@ template class MatMulGradMKLDNNKernel; namespace ops = paddle::operators; REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace, - DNNLMatMulKernel, - DNNLMatMulKernel, - DNNLMatMulKernel, DNNLMatMulKernel); + MatMulMKLDNNKernel, + MatMulMKLDNNKernel, + MatMulMKLDNNKernel, MatMulMKLDNNKernel); REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::MatMulGradMKLDNNKernel,