From 0be7157185138a5e92732225f56737d34fc532f8 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Fri, 30 Jul 2021 17:03:59 +0200 Subject: [PATCH] Added matmul_v2 BF16/FP32 BWD kernel (#34192) * test version of matmul_v2 * added matmul_v2 grad kernel * minor changes * minor changes * minor change for CI approval * CI fix * CI fix * trigger CI * changes after review, not working yet * moved ops to anonymous namespaces * changes after review --- paddle/fluid/operators/matmul_v2_op.cc | 24 +- .../operators/mkldnn/matmul_mkldnn_op.cc | 420 +++++++++--------- .../fluid/operators/mkldnn/matmul_mkldnn_op.h | 42 ++ .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 266 +++++++++-- .../mkldnn/test_matmul_v2_mkldnn_op.py | 119 ++++- 5 files changed, 598 insertions(+), 273 deletions(-) create mode 100644 paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 8ac81596a3..d39eac0759 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel { } std::vector new_dims; - if (ndims_x >= ndims_y) { + if (ndims_x > ndims_y) { new_dims.assign(dims_x.begin(), dims_x.end() - 2); - } else { + } else if (ndims_x < ndims_y) { new_dims.assign(dims_y.begin(), dims_y.end() - 2); + } else { + new_dims.reserve(ndims_x); + for (size_t i = 0; i < ndims_x - 2; ++i) { + new_dims.push_back(std::max(dims_x[i], dims_y[i])); + } } if (!x_broadcasted) { new_dims.push_back(M); @@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto out_grad_name = framework::GradVarName("Out"); - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name), - ctx.GetPlace()); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 2b3496359b..7ebd3e3856 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,34 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
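
The matmul_v2_op.cc hunk above fixes shape inference for equal-rank inputs: instead of always inheriting X's batch dimensions, batch dims are now broadcast pairwise. A minimal numpy-style sketch of the new rule — the helper name is illustrative, not part of Paddle, and validation of incompatible dims is handled elsewhere in the op:

    def infer_batch_dims(dims_x, dims_y):
        # ranks differ: the higher-rank operand supplies the batch dims
        if len(dims_x) > len(dims_y):
            return list(dims_x[:-2])
        if len(dims_x) < len(dims_y):
            return list(dims_y[:-2])
        # equal rank: broadcast pairwise, a 1 expands to the other dim
        return [max(a, b) for a, b in zip(dims_x[:-2], dims_y[:-2])]

    assert infer_batch_dims([2, 1, 12, 9], [1, 3, 9, 12]) == [2, 3]
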
*/ -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace platform { -class MKLDNNDeviceContext; -struct CPUPlace; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { +#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" using dnnl::memory; using dnnl::primitive; -using framework::DataLayout; -using framework::ExecutionContext; -using platform::GetMKLDNNFormat; -using platform::MKLDNNDeviceContext; -using platform::MKLDNNGetDataType; -using platform::to_void_cast; -using Tensor = framework::Tensor; +using paddle::framework::DataLayout; +using paddle::framework::ExecutionContext; +using paddle::framework::vectorize; +using paddle::platform::GetMKLDNNFormat; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::to_void_cast; +using Tensor = paddle::framework::Tensor; + +namespace { // Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Identity op if the tensor is not of rank 3. -static framework::Tensor FoldOuterDims(const Tensor& input) { +static Tensor FoldOuterDims(const Tensor& input) { auto output = input; auto in_dims = input.dims(); if (in_dims.size() == 3) { @@ -52,36 +42,38 @@ static framework::Tensor FoldOuterDims(const Tensor& input) { // (Warning: This requires transposing data and writes into new memory.) // Identity op if the tensor is not of rank 3. template -static framework::Tensor FoldFirstAndLastDims( - const MKLDNNDeviceContext& dev_ctx, const Tensor* input) { - auto input_dims = framework::vectorize(input->dims()); +static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx, + const Tensor* input) { + auto input_dims = vectorize(input->dims()); if (input_dims.size() != 3) { return *input; } - framework::Tensor output; + Tensor output; output.Resize({input_dims[1], input_dims[0], input_dims[2]}); - auto output_dims = framework::vectorize(output.dims()); + auto output_dims = vectorize(output.dims()); - memory::data_type input_type = framework::ToMKLDNNDataType(input->type()); - std::string key = platform::CreateKey(dev_ctx, input_dims, input->format(), - input->format(), input_type); - platform::ReorderMKLDNNHandler reorder_handler(output_dims, input->type(), - input_type, dev_ctx, - dev_ctx.GetEngine(), key); + memory::data_type input_type = + paddle::framework::ToMKLDNNDataType(input->type()); + std::string key = paddle::platform::CreateKey( + dev_ctx, input_dims, input->format(), input->format(), input_type); + paddle::platform::ReorderMKLDNNHandler reorder_handler( + output_dims, input->type(), input_type, dev_ctx, dev_ctx.GetEngine(), + key); auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - memory::format_tag::abc, platform::to_void_cast(input->data())); + memory::format_tag::abc, + paddle::platform::to_void_cast(input->data())); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( &output, memory::format_tag::bac, dev_ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, reorder_dst_memory_p); - platform::RecordEvent record_reorder("int_reorder", - platform::EventRole::kUniqueOp); + paddle::platform::RecordEvent record_reorder( + "int_reorder", paddle::platform::EventRole::kUniqueOp); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + auto& astream = MKLDNNDeviceContext::tls().get_stream(); 
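
The two folding helpers above flatten a rank-3 tensor so the 2-D grad path can consume it: FoldOuterDims is a pure reshape, while FoldFirstAndLastDims has to move data, which is why it goes through an abc->bac oneDNN reorder. A numpy sketch of the intended semantics (not the Paddle API):

    import numpy as np

    def fold_outer_dims(t):
        # P x M x N -> (P*M) x N; just a view, no data movement
        return t.reshape(-1, t.shape[-1]) if t.ndim == 3 else t

    def fold_first_and_last_dims(t):
        # P x M x N -> M x (P*N); requires a copy, like the reorder above
        if t.ndim != 3:
            return t
        return np.transpose(t, (1, 0, 2)).reshape(t.shape[1], -1)
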
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); @@ -90,19 +82,23 @@ static framework::Tensor FoldFirstAndLastDims( } template -class MatMulMKLDNNHandler : public platform::MKLDNNHandlerT { +class MatMulMKLDNNHandler + : public paddle::platform::MKLDNNHandlerT { public: MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine engine, platform::Place cpu_place, - Tensor* x, bool trans_x, Tensor* y, bool trans_y, - Tensor* out, float scale, const std::string& uniq_name) - : platform::MKLDNNHandlerT( + const mkldnn::engine engine, + paddle::platform::Place cpu_place, Tensor* x, + bool trans_x, Tensor* y, bool trans_y, Tensor* out, + float scale, const std::string& uniq_name) + : paddle::platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, - platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), - uniq_name)) { + paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()), + uniq_name)) { if (!this->isCached()) { - auto mat_dim_x = math::CreateMatrixDescriptor(x->dims(), 0, trans_x); - auto mat_dim_y = math::CreateMatrixDescriptor(y->dims(), 0, trans_y); + auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor( + x->dims(), 0, trans_x); + auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor( + y->dims(), 0, trans_y); memory::dim x_bs = mat_dim_x.batch_size_; memory::dim y_bs = mat_dim_y.batch_size_; @@ -149,20 +145,21 @@ constexpr bool IsInt8() { template constexpr bool IsBfloat16() { - return std::is_same::value; + return std::is_same::value; } // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the // original x_dim is returned. -static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) { - return x_dim.size() > 1 ? x_dim : framework::make_ddim({1, x_dim[0]}); +static paddle::framework::DDim RowMatrixDimsFromVector( + const paddle::framework::DDim& x_dim) { + return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]}); } // Get column matrix shape from a vector shape. If the ran of y_dim > 1, the // original y_dim is returned. -static framework::DDim ColumnMatrixDimsFromVector( - const framework::DDim& y_dim) { - return y_dim.size() > 1 ? y_dim : framework::make_ddim({y_dim[0], 1}); +static paddle::framework::DDim ColumnMatrixDimsFromVector( + const paddle::framework::DDim& y_dim) { + return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1}); } /** @@ -172,7 +169,7 @@ static framework::DDim ColumnMatrixDimsFromVector( * If transposed, `H,W` will be swapped. */ static void ReshapeTensorToMatrixSequence( - framework::Tensor* x, const math::MatDescriptor& descriptor) { + Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -200,14 +197,14 @@ static void ReshapeTensorToMatrixSequence( * If any of `X` and `Y` has batch size BatchSize, the out will have the * BatchSize. 
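
The Row/ColumnMatrixDimsFromVector helpers above define how 1-D operands are promoted before multiplication: X becomes a row vector and Y a column vector, so a vector-vector matmul yields a 1x1 result. In numpy-style terms (a sketch, helper names illustrative):

    def row_matrix_dims(x_dims):
        # a 1-D X of length K is treated as a 1 x K row matrix
        return list(x_dims) if len(x_dims) > 1 else [1, x_dims[0]]

    def col_matrix_dims(y_dims):
        # a 1-D Y of length K is treated as a K x 1 column matrix
        return list(y_dims) if len(y_dims) > 1 else [y_dims[0], 1]

    assert row_matrix_dims([8]) == [1, 8] and col_matrix_dims([8]) == [8, 1]
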
*/ -static void ReshapeXYOutToMatrixSequence(framework::Tensor* x, - framework::Tensor* y, - framework::Tensor* out, bool trans_x, - bool trans_y) { +static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out, + bool trans_x, bool trans_y) { auto x_dim = RowMatrixDimsFromVector(x->dims()); auto y_dim = ColumnMatrixDimsFromVector(y->dims()); - auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); - auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + auto mat_dim_x = + paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = + paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { out->Resize({mat_dim_x.height_, mat_dim_y.width_}); } else { @@ -244,8 +241,7 @@ class MatMulFactory { }; void SetDNNLEngine(const ExecutionContext& ctx) { - auto& dev_ctx = - ctx.template device_context(); + auto& dev_ctx = ctx.template device_context(); engine_ = dev_ctx.GetEngine(); } @@ -263,19 +259,19 @@ class MatMulFactory { auto axis_set = std::set(axis.begin(), axis.end()); PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, - platform::errors::InvalidArgument( + paddle::platform::errors::InvalidArgument( "In an axis array, elements must be unique.")); - PADDLE_ENFORCE_EQ( - in_rank, axis_size, - platform::errors::InvalidArgument("The input dimension's size " - "should be equal to the axis's size. " - "But received dimension is %d, " - "axis's size is %d", - in_rank, axis_size)); + PADDLE_ENFORCE_EQ(in_rank, axis_size, + paddle::platform::errors::InvalidArgument( + "The input dimension's size " + "should be equal to the axis's size. " + "But received dimension is %d, " + "axis's size is %d", + in_rank, axis_size)); PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, - platform::errors::InvalidArgument( + paddle::platform::errors::InvalidArgument( "Axis values must be ranging from 0 to (dims - 1).")); std::vector new_x(x.size()); @@ -285,8 +281,8 @@ class MatMulFactory { return new_x; } - std::pair GetInputDimsAndStrides( - const ExecutionContext& ctx, std::string input_name) { + std::pair + GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) { auto shape = ctx.Attr>("fused_reshape_" + input_name); auto axis = ctx.Attr>("fused_transpose_" + input_name); auto input_dims = ctx.Input(input_name)->dims(); @@ -297,9 +293,10 @@ class MatMulFactory { auto& MatrixDimsFromVector = input_name == "X" ? 
RowMatrixDimsFromVector : ColumnMatrixDimsFromVector; - math::MatDescriptor mat_dim = - math::CreateMatrixDescriptor(MatrixDimsFromVector(new_dims), 0, - ctx.Attr("transpose_" + input_name)); + paddle::operators::math::MatDescriptor mat_dim = + paddle::operators::math::CreateMatrixDescriptor( + MatrixDimsFromVector(new_dims), 0, + ctx.Attr("transpose_" + input_name)); memory::dims strides; if (!shape.empty()) { @@ -340,17 +337,17 @@ class MatMulFactory { } MatMulDims GetMatmulDims(const ExecutionContext& ctx) { - math::MatDescriptor mat_dim_x; + paddle::operators::math::MatDescriptor mat_dim_x; memory::dims strides_x; std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); - math::MatDescriptor mat_dim_y; + paddle::operators::math::MatDescriptor mat_dim_y; memory::dims strides_y; std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); auto x_bs = mat_dim_x.batch_size_; auto y_bs = mat_dim_y.batch_size_; PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false, - platform::errors::InvalidArgument( + paddle::platform::errors::InvalidArgument( "If batch sizes of X and Y are positive," "they have to be equal.")); @@ -448,10 +445,10 @@ class MatMulFactory { } void SetOutputFormat(const ExecutionContext& ctx) { - using platform::MKLDNNFormatForSize; + using paddle::platform::MKLDNNFormatForSize; auto* out = ctx.Output("Out"); auto format = - MKLDNNFormatForSize(out->dims().size(), MKLDNNMemoryFormat::nchw); + MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); out->set_format(format); out->set_layout(DataLayout::kMKLDNN); } @@ -495,8 +492,8 @@ static std::shared_ptr> GetPrimitiveFactory( const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); const auto batch_size = ctx.Input("X")->dims()[0]; - std::string key = platform::CreateKey(dev_ctx, batch_size, out_name); - key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name); + key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); @@ -529,161 +526,170 @@ static void ExecuteMatMul(const ExecutionContext& ctx) { } template -class DNNLMatMulKernel : public framework::OpKernel { +class DNNLMatMulKernel : public paddle::framework::OpKernel { public: void Compute(const ExecutionContext& ctx) const override { if (ctx.HasAttr("head_number")) { PADDLE_ENFORCE_EQ( ctx.Attr("head_number"), 1, - platform::errors::Unimplemented( + paddle::platform::errors::Unimplemented( "DNNL matmul doesn't support multiple heads. Expected " "head_number=1. But received `head_number` is %d", ctx.Attr("head_number"))); } - platform::MKLDNNDeviceContext::tls().log_lib_version(); + MKLDNNDeviceContext::tls().log_lib_version(); ExecuteMatMul(ctx); } }; +} // anonymous namespace + +namespace paddle { +namespace operators { + template -class MatMulGradMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const ExecutionContext& ctx) const override { - if (ctx.HasAttr("head_number")) { - PADDLE_ENFORCE_EQ( - ctx.Attr("head_number"), 1, - platform::errors::Unimplemented( - "DNNL matmul doesn't support multiple heads. Expected " - "head_number=1. 
But received `head_number` is %d", - ctx.Attr("head_number"))); - } - RunKernel(ctx); +void MatMulGradMKLDNNKernel::Compute(const ExecutionContext& ctx) const { + if (ctx.HasAttr("head_number")) { + PADDLE_ENFORCE_EQ( + ctx.Attr("head_number"), 1, + platform::errors::Unimplemented( + "DNNL matmul doesn't support multiple heads. Expected " + "head_number=1. But received `head_number` is %d", + ctx.Attr("head_number"))); } + RunKernel(ctx); +} - private: - void ExecuteMatMulGrad(const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine& engine, Tensor* x, bool trans_x, - bool is_fold_init_dims_x, Tensor* y, bool trans_y, - bool is_fold_init_dims_y, Tensor* out, - int execution_number) const { - // gradient is calculated in a different way when broadcasting is used - bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && - out->dims().size() == 2; - - Tensor x_combined, y_combined; - if (!need_combine) { - x_combined = *x; - y_combined = *y; - } else { - x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) - : FoldFirstAndLastDims(dev_ctx, x); - y_combined = is_fold_init_dims_y ? FoldOuterDims(*y) - : FoldFirstAndLastDims(dev_ctx, y); - } +template +void MatMulGradMKLDNNKernel::ExecuteMatMulGrad( + const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine& engine, Tensor* x, bool trans_x, + bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y, + Tensor* out, int execution_number) const { + // gradient is calculated in a different way when broadcasting is used + bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && + out->dims().size() == 2; + + Tensor x_combined, y_combined; + if (!need_combine) { + x_combined = *x; + y_combined = *y; + } else { + x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) + : FoldFirstAndLastDims(dev_ctx, x); + y_combined = is_fold_init_dims_y ? 
FoldOuterDims(*y) + : FoldFirstAndLastDims(dev_ctx, y); + } - MatMulMKLDNNHandler handler( - dev_ctx, engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined, - trans_y, out, ctx.Attr("alpha"), - ctx.InputName(framework::GradVarName("Out")) + - std::to_string(execution_number)); - - const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); - const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); - const auto dst_memory_p = handler.AcquireDstMemory(out); - - auto matmul_p = handler.AcquireForwardPrimitive(); - - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - matmul_p->execute(astream, matmul_args); - astream.wait(); - - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format(platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( - framework::vectorize(out->dims())))); - } - - template - void RunKernel(const ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& onednn_engine = dev_ctx.GetEngine(); - - auto x = *ctx.Input("X"); - auto y = *ctx.Input("Y"); - auto dout = *ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - bool transpose_x = ctx.Attr("transpose_X"); - bool transpose_y = ctx.Attr("transpose_Y"); - - ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } + MatMulMKLDNNHandler handler(dev_ctx, engine, ctx.GetPlace(), &x_combined, + trans_x, &y_combined, trans_y, out, alpha, + ctx.InputName(framework::GradVarName("Out")) + + std::to_string(execution_number)); + + const auto src_memory_p = handler.AcquireSrcMemory(&x_combined); + const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined); + const auto dst_memory_p = handler.AcquireDstMemory(out); + + auto matmul_p = handler.AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetMKLDNNFormat( + dst_memory_p->get_desc().reshape(vectorize(out->dims())))); +} + +template +void MatMulGradMKLDNNKernel::RunKernel(const ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto x = *ctx.Input("X"); + auto y = *ctx.Input("Y"); + auto dout = *ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr("transpose_X") + : ctx.Attr("trans_x"); + bool transpose_y = ctx.HasAttr("transpose_Y") ? 
ctx.Attr("transpose_Y") + : ctx.Attr("trans_y"); + + ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); } + } - if (transpose_x && transpose_y) { - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, - &dout, true, false, dx, 0); - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, - &x, true, false, dy, 1); - } else if (transpose_x) { - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false, - &dout, true, false, dx, 0); - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false, - &dout, false, true, dy, 1); - } else if (transpose_y) { - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, - &y, false, true, dx, 0); - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, - &x, false, true, dy, 1); - } else { - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, - &y, true, false, dx, 0); - this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, - &dout, false, true, dy, 1); + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); } + } - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } + if (transpose_x && transpose_y) { + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout, + true, false, dx, 0); + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x, + true, false, dy, 1); + } else if (transpose_x) { + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false, + &dout, true, false, dx, 0); + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false, + &dout, false, true, dy, 1); + } else if (transpose_y) { + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, + &y, false, true, dx, 0); + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x, + false, true, dy, 1); + } else { + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, + &y, true, false, dx, 0); + this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout, + false, true, dy, 1); + } + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + dx->set_format(x.format()); } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } + } + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + dy->set_format(y.format()); } } -}; +} + +template class MatMulGradMKLDNNKernel; +template class MatMulGradMKLDNNKernel; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace, - ops::DNNLMatMulKernel, - ops::DNNLMatMulKernel, - ops::DNNLMatMulKernel, - ops::DNNLMatMulKernel); + DNNLMatMulKernel, + DNNLMatMulKernel, + DNNLMatMulKernel, DNNLMatMulKernel); REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::MatMulGradMKLDNNKernel, diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h new file mode 100644 index 0000000000..725d1fff9c --- /dev/null +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
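
The four branches dispatched above are the standard matmul gradient identities, each evaluated as a forward oneDNN matmul on rearranged operands; in the no-transpose case, dX = dOut * Y^T and dY = X^T * dOut. A small numpy spot check of that case (a sketch, assuming alpha == 1):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((5, 3))
    y = rng.standard_normal((3, 4))
    dout = np.ones((5, 4))   # upstream gradient of out = x @ y

    dx = dout @ y.T          # mirrors the (&dout, false, ..., &y, true, ...) call above
    dy = x.T @ dout          # mirrors the (&x, true, ..., &dout, false, ...) call above

    # finite-difference check of one element of dx
    eps = 1e-6
    xp = x.copy()
    xp[0, 0] += eps
    fd = ((xp @ y).sum() - (x @ y).sum()) / eps
    assert abs(fd - dx[0, 0]) < 1e-4
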
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using platform::MKLDNNDeviceContext; +using framework::ExecutionContext; +using Tensor = framework::Tensor; + +template +class MatMulGradMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const ExecutionContext& ctx) const override; + + private: + void ExecuteMatMulGrad(const ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine& engine, Tensor* x, bool trans_x, + bool is_fold_init_dims_x, Tensor* y, bool trans_y, + bool is_fold_init_dims_y, Tensor* out, + int execution_number) const; + void RunKernel(const ExecutionContext& ctx) const; +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 50afd41717..b5dc096441 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h" -namespace paddle { -namespace operators { +namespace { using dnnl::memory; using dnnl::primitive; -using framework::DataLayout; -using framework::ExecutionContext; -using platform::GetMKLDNNFormat; -using platform::MKLDNNDeviceContext; -using platform::MKLDNNGetDataType; -using platform::to_void_cast; -using Tensor = framework::Tensor; +using paddle::framework::DataLayout; +using paddle::framework::ExecutionContext; +using paddle::platform::GetMKLDNNFormat; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::to_void_cast; +using Tensor = paddle::framework::Tensor; +using paddle::framework::vectorize; +using paddle::framework::make_ddim; +using paddle::framework::GradVarName; template -class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT { +class MatMulV2MKLDNNHandler + : public paddle::platform::MKLDNNHandlerT { public: MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine engine, platform::Place cpu_place, - std::vector& x_dims, bool trans_x, - std::vector& y_dims, bool trans_y, + const mkldnn::engine engine, + paddle::platform::Place cpu_place, + const std::vector& x_org_dims, bool trans_x, + const std::vector& y_org_dims, bool trans_y, const std::string& uniq_name) - : platform::MKLDNNHandlerT( + : paddle::platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, - platform::CreateKey(dev_ctx, x_dims, uniq_name)) { + paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) { if (!this->isCached()) { // M X K * K X N + 
std::vector x_dims(x_org_dims); + std::vector y_dims(y_org_dims); + const int MB_idx = x_dims.size() - 3; const int H_idx = x_dims.size() - 2; const int W_idx = x_dims.size() - 1; @@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT { }; template -class MatMulV2MKLDNNKernel : public framework::OpKernel { +class MatMulV2MKLDNNKernel + : public paddle::operators::MatMulGradMKLDNNKernel { public: void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } + protected: + void ExecuteMatMul(const ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine onednn_engine, + paddle::platform::Place cpu_place, const Tensor* x, + std::vector& x_dims, bool trans_x, + const Tensor* y, std::vector& y_dims, + bool trans_y, Tensor* out, std::vector& out_dims, + int execution_number = 0) const { + MatMulV2MKLDNNHandler handler( + dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims, + trans_y, ctx.InputName("X") + std::to_string(execution_number)); + + const auto src_memory_p = handler.AcquireSrcMemory(x); + const auto weights_memory_p = handler.AcquireWeightsMemory(y); + const auto dst_memory_p = handler.AcquireDstMemory(out); + + auto matmul_p = handler.AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = MKLDNNDeviceContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); + + out->set_layout(paddle::framework::DataLayout::kMKLDNN); + out->set_format( + GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); + } + private: void CalculateMatrixDims(const ExecutionContext& ctx, const std::vector& x_dims, @@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { std::vector& out_dims, Tensor* out) const { if (x_dims.size() == 1) { x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; + } else if (x_dims.size() == 2) { + x_bd_dims[2] = x_dims[1]; + x_bd_dims[1] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { x_bd_dims[i] = x_dims[i]; @@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { } if (y_dims.size() == 1) { y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; + } else if (y_dims.size() == 2) { + y_bd_dims[2] = y_dims[1]; + y_bd_dims[1] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { y_bd_dims[i] = y_dims[i]; @@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { for (size_t i = 0; i < x_dims.size() - 2; ++i) { PADDLE_ENFORCE_EQ( x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, - platform::errors::InvalidArgument( + paddle::platform::errors::InvalidArgument( "Tensor dimensions are incorrect for broadcasting." 
"Dimensions in X and Y must be same or equal to 1, but " "received x_dim[%d]=%d and y_dims[%d]= %d", i, x_dims[i], i, y_dims[i])); out_dims[i] = std::max(x_dims[i], y_dims[i]); } - out->Resize(framework::make_ddim(out_dims)); + out->Resize(make_ddim(out_dims)); } } @@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); - auto x_dims = framework::vectorize(x->dims()); - auto y_dims = framework::vectorize(y->dims()); - auto out_dims = framework::vectorize(out->dims()); + auto x_dims = vectorize(x->dims()); + auto y_dims = vectorize(y->dims()); + auto out_dims = vectorize(out->dims()); int ndims = std::max(x->dims().size(), y->dims().size()); ndims = std::max(ndims, 3); @@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel { CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims, out); - MatMulV2MKLDNNHandler handler(dev_ctx, onednn_engine, ctx.GetPlace(), - x_bd_dims, trans_x, y_bd_dims, trans_y, - ctx.InputName("X")); + ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims, + trans_x, y, y_bd_dims, trans_y, out, out_dims); + } +}; - const auto src_memory_p = handler.AcquireSrcMemory(x); - const auto weights_memory_p = handler.AcquireWeightsMemory(y); - const auto dst_memory_p = handler.AcquireDstMemory(out); +template +class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { + public: + void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } - auto matmul_p = handler.AcquireForwardPrimitive(); + private: + void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp, + Tensor* dy_tmp, + const std::vector& dx_dims, + const std::vector& dy_dims, + std::vector& dx_bd_dims, + std::vector& dy_bd_dims) const { + for (size_t i = 0; i < dx_dims.size() - 2; ++i) { + if (dx_dims[i] != dy_dims[i]) { + if (dx_dims[i] == 1) { + dx_bd_dims[i] = dy_dims[i]; + } else { + dy_bd_dims[i] = dx_dims[i]; + } + } + } - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; + dx_tmp->Resize(make_ddim(dx_bd_dims)); + dx_tmp->mutable_data(ctx.GetPlace()); + dy_tmp->Resize(make_ddim(dy_bd_dims)); + dy_tmp->mutable_data(ctx.GetPlace()); + } + + void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine onednn_engine, + const Tensor* dx_tmp, Tensor* dx, + std::vector dx_dims) const { + paddle::platform::ReductionMKLDNNHandler handler( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine, + ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims); + + auto src_memory_p = handler.AcquireSrcMemory(dx_tmp); + auto dst_memory_p = handler.AcquireDstMemory(dx); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; auto& astream = MKLDNNDeviceContext::tls().get_stream(); - matmul_p->execute(astream, matmul_args); + auto reduction_p = handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, reduction_args); astream.wait(); + } - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format( - GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); + void RunKernel(const ExecutionContext& ctx) const { + const auto& dev_ctx = ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + + auto x_dims = 
vectorize(x->dims()); + auto y_dims = vectorize(y->dims()); + + bool is_broadcast = true; + if (x_dims.size() <= 2 || y_dims.size() <= 2) { + is_broadcast = false; + } else if (x_dims.size() != y_dims.size()) { + is_broadcast = true; + } else { + is_broadcast = + !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2, + y_dims.cbegin()); + } + + // if no broadcasting is needed, we can simply use matmul's grad and avoid + // using reduce_sum + if (!is_broadcast) { + paddle::operators::MatMulGradMKLDNNKernel::Compute(ctx); + return; + } + + auto* dout = ctx.Input(GradVarName("Out")); + auto* dx = ctx.Output(GradVarName("X")); + auto* dy = ctx.Output(GradVarName("Y")); + + bool trans_x = ctx.Attr("trans_x"); + bool trans_y = ctx.Attr("trans_y"); + auto dout_dims = vectorize(dout->dims()); + + int ndims = std::max(x->dims().size(), y->dims().size()); + ndims = std::max(ndims, 3); + + // in broadcasting scenario new memory is required because + // reduce sum must be calculated upon broadcasted dims + Tensor dx_tmp, dy_tmp; + + std::vector dx_bd_dims(x_dims); + std::vector dy_bd_dims(y_dims); + + CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims, + dy_bd_dims); + + if (trans_x && trans_y) { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, + y_dims, true, dout, dout_dims, true, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims, + 2); + } else if (trans_x) { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, + y_dims, false, dout, dout_dims, true, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, + x_dims, false, dout, dout_dims, false, &dy_tmp, + dy_bd_dims, 2); + } else if (trans_y) { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, false, y, y_dims, false, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, true, x, x_dims, false, &dy_tmp, + dy_bd_dims, 2); + } else { + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, false, y, y_dims, true, &dx_tmp, + dx_bd_dims, 1); + this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, + x_dims, true, dout, dout_dims, false, &dy_tmp, + dy_bd_dims, 2); + } + + if (x_dims != dx_bd_dims) { + ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, + x_dims); + } else { + *dx = std::move(dx_tmp); + } + if (y_dims != dy_bd_dims) { + ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, + y_dims); + } else { + *dy = std::move(dy_tmp); + } + + dx->set_layout(paddle::framework::DataLayout::kMKLDNN); + dx->set_format(x->format()); + dy->set_layout(paddle::framework::DataLayout::kMKLDNN); + dy->set_format(y->format()); } }; -} // namespace operators -} // namespace paddle +} // anonymous namespace namespace ops = paddle::operators; REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, - ops::MatMulV2MKLDNNKernel, - ops::MatMulV2MKLDNNKernel); + MatMulV2MKLDNNKernel, + MatMulV2MKLDNNKernel); -// REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace, -// ops::MatMulV2GradMKLDNNKernel, -// ops::MatMulV2GradMKLDNNKernel); +REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace, + MatMulV2GradMKLDNNKernel, + MatMulV2GradMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py 
b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index ea06e2c447..5cc6651bb0 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest +from functools import reduce import numpy as np from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 @@ -23,14 +24,12 @@ import paddle import paddle.fluid as fluid import paddle.fluid.framework as framework -paddle.enable_static() - -def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): +def reference_matmul(X, Y, transpose_x=False, transpose_y=False): """Reference forward implementation using np.matmul.""" # np.matmul does not support the transpose flags, so we manually # transpose X and Y appropriately. - if transpose_X: + if transpose_x: if X.ndim == 1: X = X.reshape((X.size, )) elif X.ndim == 2: @@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): dim = [i for i in range(len(X.shape))] dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] X = np.transpose(X, tuple(dim)) - if transpose_Y: + if transpose_y: if Y.ndim == 1: Y = Y.reshape((Y.size, )) else: @@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp( class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (1, 1, 12, 4) - self.y_shape = (1, 2, 4, 12) + self.x_shape = (2, 1, 12, 9) + self.y_shape = (1, 3, 9, 12) self.trans_x = False self.trans_y = False @@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2( class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (2, 2, 5, 4) - self.y_shape = (2, 2, 5, 3) + self.x_shape = (2, 2, 7, 4) + self.y_shape = (2, 2, 7, 5) self.trans_x = True self.trans_y = False @@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp( TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (3, 1, 6, 5) + self.x_shape = (3, 1, 6, 7) self.y_shape = (1, 2, 6, 9) self.trans_x = True self.trans_y = False @@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): def config(self): - self.x_shape = (2, 1, 40) - self.y_shape = (40) + self.x_shape = (2, 1, 100) + self.y_shape = (100) self.trans_x = False self.trans_y = False @@ -245,6 +244,8 @@ def create_bf16_test_class(parent): 'X': convert_float_to_uint16(x), 'Y': convert_float_to_uint16(y) } + self.x_fp32 = x + self.y_fp32 = y def set_dtype_attr(self): self.attrs['mkldnn_data_type'] = "bfloat16" @@ -253,7 +254,99 @@ def create_bf16_test_class(parent): self.check_output_with_place(core.CPUPlace()) def test_check_grad(self): - pass + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + user_defined_grads=[self.dx, self.dy], + user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) + + def matmul_grad(self, x, transpose_x, y, transpose_y): + x = np.transpose( + x, self.shape_transpose_axes[x.ndim]) if transpose_x else x + y = np.transpose( + y, self.shape_transpose_axes[y.ndim]) if transpose_y else y + + return np.matmul(x, y) + + def calculate_grads(self): + self.shape_transpose_axes = { + 2: [1, 0], + 3: [0, 2, 1], + 4: [0, 1, 3, 2], + 5: [0, 1, 
2, 4, 3] + } + + # expand vector so it will be a valid matrix for multiplication + if self.x_fp32.ndim == 1: + self.x_fp32 = np.expand_dims(self.x_fp32, axis=0) + if self.y_fp32.ndim == 1: + self.y_fp32 = np.expand_dims(self.y_fp32, axis=1) + + x_transpose_axes = self.shape_transpose_axes[self.x_fp32.ndim] + y_transpose_axes = self.shape_transpose_axes[self.y_fp32.ndim] + + x = np.transpose(self.x_fp32, x_transpose_axes) if self.attrs[ + 'trans_x'] is True else self.x_fp32 + y = np.transpose(self.y_fp32, y_transpose_axes) if self.attrs[ + 'trans_y'] is True else self.y_fp32 + + dout = np.matmul(x, y) + + x_shape = x.shape + y_shape = y.shape + + if x.ndim <= 2 or y.ndim <= 2: + is_broadcast = False + elif x.ndim != y.ndim: + is_broadcast = True + else: + is_broadcast = x.shape[0:-2] != y.shape[0:-2] + + if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True: + self.dx = self.matmul_grad(self.y_fp32, True, dout, True) + self.dy = self.matmul_grad(dout, True, self.x_fp32, True) + elif self.attrs['trans_x'] is True and self.attrs[ + 'trans_y'] is False: + self.dx = self.matmul_grad(self.y_fp32, False, dout, True) + self.dy = self.matmul_grad(self.x_fp32, False, dout, False) + elif self.attrs['trans_x'] is False and self.attrs[ + 'trans_y'] is True: + self.dx = self.matmul_grad(dout, False, self.y_fp32, False) + self.dy = self.matmul_grad(dout, True, self.x_fp32, False) + else: + self.dx = self.matmul_grad(dout, False, self.y_fp32, True) + self.dy = self.matmul_grad(self.x_fp32, True, dout, False) + + if is_broadcast: + x_reduce_axis = [] + y_reduce_axis = [] + for index, ( + first, second + ) in enumerate(zip(x_shape[0:-2], self.dx.shape[0:-2])): + if first != second: + x_reduce_axis.append(index) + + for index, ( + first, second + ) in enumerate(zip(y_shape[0:-2], self.dy.shape[0:-2])): + if first != second: + y_reduce_axis.append(index) + + if x_reduce_axis: + self.dx = self.dx.sum(axis=tuple(x_reduce_axis), + keepdims=True) + if y_reduce_axis: + self.dy = self.dy.sum(axis=tuple(y_reduce_axis), + keepdims=True) + + # after multiplying with vector one dimension is deleted from tensor + if len(x_shape) == 2 and x_shape[0] == 1: + dout = dout.sum(axis=-2) + if len(y_shape) == 2 and y_shape[1] == 1: + dout = dout.sum(axis=-1) + + self.dout = dout cls_name = "{0}_{1}".format(parent.__name__, "BF16") TestMatMulV2Bf16OneDNNOp.__name__ = cls_name -- GitLab
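
The reduce_sum path exercised by the BF16 grad reference above follows directly from broadcasting: a batch dimension of size 1 that was expanded in the forward pass must have its gradient summed back over the expanded axis, which is what ReduceSumForMatmulGradOutput does with dnnl::algorithm::reduction_sum. A numpy sketch of the shape bookkeeping (illustrative shapes only):

    import numpy as np

    x = np.random.rand(2, 1, 12, 9)   # batch dim of size 1 broadcasts
    y = np.random.rand(1, 3, 9, 12)   # against a batch dim of size 3
    dout = np.ones(np.matmul(x, y).shape)                    # (2, 3, 12, 12)

    # full broadcast-shaped grad, then sum over the broadcast axis
    dx_bd = np.matmul(dout, np.transpose(y, (0, 1, 3, 2)))   # (2, 3, 12, 9)
    dx = dx_bd.sum(axis=1, keepdims=True)                    # (2, 1, 12, 9)
    assert dx.shape == x.shape
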