未验证 提交 0be71571 编写于 作者: J jakpiase 提交者: GitHub

Added matmul_v2 BF16/FP32 BWD kernel (#34192)

* test version of matmul_v2

* added matmul_v2 grad kernel

* minor changes

* minor changes

* minor change for CI approval

* CI fix

* CI fix

* trigger CI

* changes after review, not working yet

* moved ops to anonymous namespaces

* changes after review
上级 44e4d57b
...@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel {
} }
std::vector<int64_t> new_dims; std::vector<int64_t> new_dims;
if (ndims_x >= ndims_y) { if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2); new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else { } else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2); new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
} }
if (!x_broadcasted) { if (!x_broadcasted) {
new_dims.push_back(M); new_dims.push_back(M);
...@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { ...@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto out_grad_name = framework::GradVarName("Out"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(
return framework::OpKernelType( ctx, framework::GradVarName("Out"));
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
ctx.GetPlace()); #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,34 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,34 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace platform {
class MKLDNNDeviceContext;
struct CPUPlace;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
using dnnl::memory; using dnnl::memory;
using dnnl::primitive; using dnnl::primitive;
using framework::DataLayout; using paddle::framework::DataLayout;
using framework::ExecutionContext; using paddle::framework::ExecutionContext;
using platform::GetMKLDNNFormat; using paddle::framework::vectorize;
using platform::MKLDNNDeviceContext; using paddle::platform::GetMKLDNNFormat;
using platform::MKLDNNGetDataType; using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast; using paddle::platform::MKLDNNGetDataType;
using Tensor = framework::Tensor; using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor;
namespace {
// Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
static framework::Tensor FoldOuterDims(const Tensor& input) { static Tensor FoldOuterDims(const Tensor& input) {
auto output = input; auto output = input;
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() == 3) { if (in_dims.size() == 3) {
...@@ -52,36 +42,38 @@ static framework::Tensor FoldOuterDims(const Tensor& input) { ...@@ -52,36 +42,38 @@ static framework::Tensor FoldOuterDims(const Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.) // (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
template <typename T> template <typename T>
static framework::Tensor FoldFirstAndLastDims( static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
const MKLDNNDeviceContext& dev_ctx, const Tensor* input) { const Tensor* input) {
auto input_dims = framework::vectorize(input->dims()); auto input_dims = vectorize(input->dims());
if (input_dims.size() != 3) { if (input_dims.size() != 3) {
return *input; return *input;
} }
framework::Tensor output; Tensor output;
output.Resize({input_dims[1], input_dims[0], input_dims[2]}); output.Resize({input_dims[1], input_dims[0], input_dims[2]});
auto output_dims = framework::vectorize(output.dims()); auto output_dims = vectorize(output.dims());
memory::data_type input_type = framework::ToMKLDNNDataType(input->type()); memory::data_type input_type =
std::string key = platform::CreateKey(dev_ctx, input_dims, input->format(), paddle::framework::ToMKLDNNDataType(input->type());
input->format(), input_type); std::string key = paddle::platform::CreateKey(
platform::ReorderMKLDNNHandler reorder_handler(output_dims, input->type(), dev_ctx, input_dims, input->format(), input->format(), input_type);
input_type, dev_ctx, paddle::platform::ReorderMKLDNNHandler reorder_handler(
dev_ctx.GetEngine(), key); output_dims, input->type(), input_type, dev_ctx, dev_ctx.GetEngine(),
key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
memory::format_tag::abc, platform::to_void_cast(input->data<T>())); memory::format_tag::abc,
paddle::platform::to_void_cast(input->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
&output, memory::format_tag::bac, dev_ctx.GetPlace()); &output, memory::format_tag::bac, dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p); reorder_dst_memory_p);
platform::RecordEvent record_reorder("int_reorder", paddle::platform::RecordEvent record_reorder(
platform::EventRole::kUniqueOp); "int_reorder", paddle::platform::EventRole::kUniqueOp);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
...@@ -90,19 +82,23 @@ static framework::Tensor FoldFirstAndLastDims( ...@@ -90,19 +82,23 @@ static framework::Tensor FoldFirstAndLastDims(
} }
template <typename T> template <typename T>
class MatMulMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> { class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
public: public:
MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place, const mkldnn::engine engine,
Tensor* x, bool trans_x, Tensor* y, bool trans_y, paddle::platform::Place cpu_place, Tensor* x,
Tensor* out, float scale, const std::string& uniq_name) bool trans_x, Tensor* y, bool trans_y, Tensor* out,
: platform::MKLDNNHandlerT<T, dnnl::matmul>( float scale, const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place, dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
uniq_name)) { uniq_name)) {
if (!this->isCached()) { if (!this->isCached()) {
auto mat_dim_x = math::CreateMatrixDescriptor(x->dims(), 0, trans_x); auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
auto mat_dim_y = math::CreateMatrixDescriptor(y->dims(), 0, trans_y); x->dims(), 0, trans_x);
auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
y->dims(), 0, trans_y);
memory::dim x_bs = mat_dim_x.batch_size_; memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_; memory::dim y_bs = mat_dim_y.batch_size_;
...@@ -149,20 +145,21 @@ constexpr bool IsInt8() { ...@@ -149,20 +145,21 @@ constexpr bool IsInt8() {
template <typename T> template <typename T>
constexpr bool IsBfloat16() { constexpr bool IsBfloat16() {
return std::is_same<T, platform::bfloat16>::value; return std::is_same<T, paddle::platform::bfloat16>::value;
} }
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned. // original x_dim is returned.
static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) { static paddle::framework::DDim RowMatrixDimsFromVector(
return x_dim.size() > 1 ? x_dim : framework::make_ddim({1, x_dim[0]}); const paddle::framework::DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
} }
// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the // Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
// original y_dim is returned. // original y_dim is returned.
static framework::DDim ColumnMatrixDimsFromVector( static paddle::framework::DDim ColumnMatrixDimsFromVector(
const framework::DDim& y_dim) { const paddle::framework::DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : framework::make_ddim({y_dim[0], 1}); return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
} }
/** /**
...@@ -172,7 +169,7 @@ static framework::DDim ColumnMatrixDimsFromVector( ...@@ -172,7 +169,7 @@ static framework::DDim ColumnMatrixDimsFromVector(
* If transposed, `H,W` will be swapped. * If transposed, `H,W` will be swapped.
*/ */
static void ReshapeTensorToMatrixSequence( static void ReshapeTensorToMatrixSequence(
framework::Tensor* x, const math::MatDescriptor& descriptor) { Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
int64_t h, w; int64_t h, w;
h = descriptor.height_; h = descriptor.height_;
w = descriptor.width_; w = descriptor.width_;
...@@ -200,14 +197,14 @@ static void ReshapeTensorToMatrixSequence( ...@@ -200,14 +197,14 @@ static void ReshapeTensorToMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the * If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize. * BatchSize.
*/ */
static void ReshapeXYOutToMatrixSequence(framework::Tensor* x, static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
framework::Tensor* y, bool trans_x, bool trans_y) {
framework::Tensor* out, bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixDimsFromVector(x->dims()); auto x_dim = RowMatrixDimsFromVector(x->dims());
auto y_dim = ColumnMatrixDimsFromVector(y->dims()); auto y_dim = ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); auto mat_dim_x =
auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_}); out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else { } else {
...@@ -244,8 +241,7 @@ class MatMulFactory { ...@@ -244,8 +241,7 @@ class MatMulFactory {
}; };
void SetDNNLEngine(const ExecutionContext& ctx) { void SetDNNLEngine(const ExecutionContext& ctx) {
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
engine_ = dev_ctx.GetEngine(); engine_ = dev_ctx.GetEngine();
} }
...@@ -263,19 +259,19 @@ class MatMulFactory { ...@@ -263,19 +259,19 @@ class MatMulFactory {
auto axis_set = std::set<int>(axis.begin(), axis.end()); auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size, PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique.")); "In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(in_rank, axis_size,
in_rank, axis_size, paddle::platform::errors::InvalidArgument(
platform::errors::InvalidArgument("The input dimension's size " "The input dimension's size "
"should be equal to the axis's size. " "should be equal to the axis's size. "
"But received dimension is %d, " "But received dimension is %d, "
"axis's size is %d", "axis's size is %d",
in_rank, axis_size)); in_rank, axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size, PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1).")); "Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size()); std::vector<int64_t> new_x(x.size());
...@@ -285,8 +281,8 @@ class MatMulFactory { ...@@ -285,8 +281,8 @@ class MatMulFactory {
return new_x; return new_x;
} }
std::pair<math::MatDescriptor, memory::dims> GetInputDimsAndStrides( std::pair<paddle::operators::math::MatDescriptor, memory::dims>
const ExecutionContext& ctx, std::string input_name) { GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name); auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name); auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims(); auto input_dims = ctx.Input<Tensor>(input_name)->dims();
...@@ -297,8 +293,9 @@ class MatMulFactory { ...@@ -297,8 +293,9 @@ class MatMulFactory {
auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector
: ColumnMatrixDimsFromVector; : ColumnMatrixDimsFromVector;
math::MatDescriptor mat_dim = paddle::operators::math::MatDescriptor mat_dim =
math::CreateMatrixDescriptor(MatrixDimsFromVector(new_dims), 0, paddle::operators::math::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0,
ctx.Attr<bool>("transpose_" + input_name)); ctx.Attr<bool>("transpose_" + input_name));
memory::dims strides; memory::dims strides;
...@@ -340,17 +337,17 @@ class MatMulFactory { ...@@ -340,17 +337,17 @@ class MatMulFactory {
} }
MatMulDims GetMatmulDims(const ExecutionContext& ctx) { MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
math::MatDescriptor mat_dim_x; paddle::operators::math::MatDescriptor mat_dim_x;
memory::dims strides_x; memory::dims strides_x;
std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
math::MatDescriptor mat_dim_y; paddle::operators::math::MatDescriptor mat_dim_y;
memory::dims strides_y; memory::dims strides_y;
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
auto x_bs = mat_dim_x.batch_size_; auto x_bs = mat_dim_x.batch_size_;
auto y_bs = mat_dim_y.batch_size_; auto y_bs = mat_dim_y.batch_size_;
PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false, PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"If batch sizes of X and Y are positive," "If batch sizes of X and Y are positive,"
"they have to be equal.")); "they have to be equal."));
...@@ -448,10 +445,10 @@ class MatMulFactory { ...@@ -448,10 +445,10 @@ class MatMulFactory {
} }
void SetOutputFormat(const ExecutionContext& ctx) { void SetOutputFormat(const ExecutionContext& ctx) {
using platform::MKLDNNFormatForSize; using paddle::platform::MKLDNNFormatForSize;
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto format = auto format =
MKLDNNFormatForSize(out->dims().size(), MKLDNNMemoryFormat::nchw); MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format); out->set_format(format);
out->set_layout(DataLayout::kMKLDNN); out->set_layout(DataLayout::kMKLDNN);
} }
...@@ -495,8 +492,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -495,8 +492,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const auto& out_name = ctx.OutputName("Out"); const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0]; const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
std::string key = platform::CreateKey(dev_ctx, batch_size, out_name); std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto factory = auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key)); std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
...@@ -529,26 +526,29 @@ static void ExecuteMatMul(const ExecutionContext& ctx) { ...@@ -529,26 +526,29 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
} }
template <typename T> template <typename T>
class DNNLMatMulKernel : public framework::OpKernel<T> { class DNNLMatMulKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) { if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1, ctx.Attr<int>("head_number"), 1,
platform::errors::Unimplemented( paddle::platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected " "DNNL matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d", "head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number"))); ctx.Attr<int>("head_number")));
} }
platform::MKLDNNDeviceContext::tls().log_lib_version(); MKLDNNDeviceContext::tls().log_lib_version();
ExecuteMatMul<T, T>(ctx); ExecuteMatMul<T, T>(ctx);
} }
}; };
} // anonymous namespace
namespace paddle {
namespace operators {
template <typename T> template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { void MatMulGradMKLDNNKernel<T>::Compute(const ExecutionContext& ctx) const {
public:
void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) { if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1, ctx.Attr<int>("head_number"), 1,
...@@ -557,16 +557,15 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -557,16 +557,15 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
"head_number=1. But received `head_number` is %d", "head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number"))); ctx.Attr<int>("head_number")));
} }
RunKernel<T>(ctx); RunKernel(ctx);
} }
private: template <typename T>
void ExecuteMatMulGrad(const ExecutionContext& ctx, void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x, const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
bool is_fold_init_dims_y, Tensor* out, Tensor* out, int execution_number) const {
int execution_number) const {
// gradient is calculated in a different way when broadcasting is used // gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2; out->dims().size() == 2;
...@@ -582,9 +581,10 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -582,9 +581,10 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
: FoldFirstAndLastDims<T>(dev_ctx, y); : FoldFirstAndLastDims<T>(dev_ctx, y);
} }
MatMulMKLDNNHandler<T> handler( float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
dev_ctx, engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined,
trans_y, out, ctx.Attr<float>("alpha"), MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
trans_x, &y_combined, trans_y, out, alpha,
ctx.InputName(framework::GradVarName("Out")) + ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number)); std::to_string(execution_number));
...@@ -604,12 +604,12 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -604,12 +604,12 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
astream.wait(); astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN); out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( out->set_format(platform::GetMKLDNNFormat(
framework::vectorize<int64_t>(out->dims())))); dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
} }
template <typename Tout = T> template <typename T>
void RunKernel(const ExecutionContext& ctx) const { void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
...@@ -620,10 +620,13 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -620,10 +620,13 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
bool transpose_x = ctx.Attr<bool>("transpose_X"); bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr<bool>("transpose_X")
bool transpose_y = ctx.Attr<bool>("transpose_Y"); : ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y") ? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims; framework::DDim dx_dims;
if (dx) { if (dx) {
dx_dims = dx->dims(); dx_dims = dx->dims();
...@@ -641,10 +644,10 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -641,10 +644,10 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
} }
if (transpose_x && transpose_y) { if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
&dout, true, false, dx, 0); true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
&x, true, false, dy, 1); true, false, dy, 1);
} else if (transpose_x) { } else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0); &dout, true, false, dx, 0);
...@@ -653,37 +656,40 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -653,37 +656,40 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
} else if (transpose_y) { } else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0); &y, false, true, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
&x, false, true, dy, 1); false, true, dy, 1);
} else { } else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0); &y, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
&dout, false, true, dy, 1); false, true, dy, 1);
} }
if (dx) { if (dx) {
if (dx_dims != x.dims()) { if (dx_dims != x.dims()) {
dx->Resize(dx_dims); dx->Resize(dx_dims);
dx->set_format(x.format());
} }
} }
if (dy) { if (dy) {
if (dy_dims != y.dims()) { if (dy_dims != y.dims()) {
dy->Resize(dy_dims); dy->Resize(dy_dims);
dy->set_format(y.format());
} }
} }
} }
};
template class MatMulGradMKLDNNKernel<float>;
template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
ops::DNNLMatMulKernel<float>, DNNLMatMulKernel<float>,
ops::DNNLMatMulKernel<paddle::platform::bfloat16>, DNNLMatMulKernel<paddle::platform::bfloat16>,
ops::DNNLMatMulKernel<int8_t>, DNNLMatMulKernel<int8_t>, DNNLMatMulKernel<uint8_t>);
ops::DNNLMatMulKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulGradMKLDNNKernel<float>, ops::MatMulGradMKLDNNKernel<float>,
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using platform::MKLDNNDeviceContext;
using framework::ExecutionContext;
using Tensor = framework::Tensor;
template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override;
private:
void ExecuteMatMulGrad(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const;
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
} // namespace paddle
...@@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace {
namespace operators {
using dnnl::memory; using dnnl::memory;
using dnnl::primitive; using dnnl::primitive;
using framework::DataLayout; using paddle::framework::DataLayout;
using framework::ExecutionContext; using paddle::framework::ExecutionContext;
using platform::GetMKLDNNFormat; using paddle::platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType; using paddle::platform::MKLDNNGetDataType;
using platform::to_void_cast; using paddle::platform::to_void_cast;
using Tensor = framework::Tensor; using Tensor = paddle::framework::Tensor;
using paddle::framework::vectorize;
using paddle::framework::make_ddim;
using paddle::framework::GradVarName;
template <typename T> template <typename T>
class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> { class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
public: public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place, const mkldnn::engine engine,
std::vector<int64_t>& x_dims, bool trans_x, paddle::platform::Place cpu_place,
std::vector<int64_t>& y_dims, bool trans_y, const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name) const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::matmul>( : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place, dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, x_dims, uniq_name)) { paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) { if (!this->isCached()) {
// M X K * K X N // M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3; const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2; const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1; const int W_idx = x_dims.size() - 1;
...@@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> { ...@@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
}; };
template <typename T> template <typename T>
class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { class MatMulV2MKLDNNKernel
: public paddle::operators::MatMulGradMKLDNNKernel<T> {
public: public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
protected:
void ExecuteMatMul(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
paddle::platform::Place cpu_place, const Tensor* x,
std::vector<int64_t>& x_dims, bool trans_x,
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(
dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
trans_y, ctx.InputName("X") + std::to_string(execution_number));
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
}
private: private:
void CalculateMatrixDims(const ExecutionContext& ctx, void CalculateMatrixDims(const ExecutionContext& ctx,
const std::vector<int64_t>& x_dims, const std::vector<int64_t>& x_dims,
...@@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t>& out_dims, Tensor* out) const { std::vector<int64_t>& out_dims, Tensor* out) const {
if (x_dims.size() == 1) { if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
x_bd_dims[2] = x_dims[1];
x_bd_dims[1] = x_dims[0];
} else { } else {
for (size_t i = 0; i < x_dims.size(); ++i) { for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i]; x_bd_dims[i] = x_dims[i];
...@@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
} }
if (y_dims.size() == 1) { if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
y_bd_dims[2] = y_dims[1];
y_bd_dims[1] = y_dims[0];
} else { } else {
for (size_t i = 0; i < y_dims.size(); ++i) { for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i]; y_bd_dims[i] = y_dims[i];
...@@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < x_dims.size() - 2; ++i) { for (size_t i = 0; i < x_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting." "Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but " "Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d", "received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_dims[i], i, y_dims[i])); i, x_dims[i], i, y_dims[i]));
out_dims[i] = std::max(x_dims[i], y_dims[i]); out_dims[i] = std::max(x_dims[i], y_dims[i]);
} }
out->Resize(framework::make_ddim(out_dims)); out->Resize(make_ddim(out_dims));
} }
} }
...@@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x"); bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
auto x_dims = framework::vectorize(x->dims()); auto x_dims = vectorize(x->dims());
auto y_dims = framework::vectorize(y->dims()); auto y_dims = vectorize(y->dims());
auto out_dims = framework::vectorize(out->dims()); auto out_dims = vectorize(out->dims());
int ndims = std::max(x->dims().size(), y->dims().size()); int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3); ndims = std::max(ndims, 3);
...@@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims, CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
out); out);
MatMulV2MKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
x_bd_dims, trans_x, y_bd_dims, trans_y, trans_x, y, y_bd_dims, trans_y, out, out_dims);
ctx.InputName("X")); }
};
const auto src_memory_p = handler.AcquireSrcMemory(x); template <typename T>
const auto weights_memory_p = handler.AcquireWeightsMemory(y); class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
const auto dst_memory_p = handler.AcquireDstMemory(out); public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
auto matmul_p = handler.AcquireForwardPrimitive(); private:
void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp,
Tensor* dy_tmp,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& dy_dims,
std::vector<int64_t>& dx_bd_dims,
std::vector<int64_t>& dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
dx_bd_dims[i] = dy_dims[i];
} else {
dy_bd_dims[i] = dx_dims[i];
}
}
}
std::unordered_map<int, memory> matmul_args = { dx_tmp->Resize(make_ddim(dx_bd_dims));
{DNNL_ARG_SRC, *src_memory_p}, dx_tmp->mutable_data<T>(ctx.GetPlace());
{DNNL_ARG_WEIGHTS, *weights_memory_p}, dy_tmp->Resize(make_ddim(dy_bd_dims));
{DNNL_ARG_DST, *dst_memory_p}}; dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = MKLDNNDeviceContext::tls().get_stream(); auto& astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait(); astream.wait();
}
out->set_layout(framework::DataLayout::kMKLDNN); void RunKernel(const ExecutionContext& ctx) const {
out->set_format( const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(y->dims());
bool is_broadcast = true;
if (x_dims.size() <= 2 || y_dims.size() <= 2) {
is_broadcast = false;
} else if (x_dims.size() != y_dims.size()) {
is_broadcast = true;
} else {
is_broadcast =
!std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
}
// if no broadcasting is needed, we can simply use matmul's grad and avoid
// using reduce_sum
if (!is_broadcast) {
paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
return;
}
auto* dout = ctx.Input<Tensor>(GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(GradVarName("X"));
auto* dy = ctx.Output<Tensor>(GradVarName("Y"));
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto dout_dims = vectorize(dout->dims());
int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3);
// in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims
Tensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims,
dy_bd_dims);
if (trans_x && trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, true, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
2);
} else if (trans_x) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, false, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, false, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else if (trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, false, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, true, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
}
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims);
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims);
} else {
*dy = std::move(dy_tmp);
}
dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_format(x->format());
dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
} }
}; };
} // namespace operators } // anonymous namespace
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulV2MKLDNNKernel<float>, MatMulV2MKLDNNKernel<float>,
ops::MatMulV2MKLDNNKernel<paddle::platform::bfloat16>); MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
// REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace,
// ops::MatMulV2GradMKLDNNKernel<float>, MatMulV2GradMKLDNNKernel<float>,
// ops::MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>); MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from functools import reduce
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
...@@ -23,14 +24,12 @@ import paddle ...@@ -23,14 +24,12 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
paddle.enable_static()
def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
"""Reference forward implementation using np.matmul.""" """Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually # np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately. # transpose X and Y appropriately.
if transpose_X: if transpose_x:
if X.ndim == 1: if X.ndim == 1:
X = X.reshape((X.size, )) X = X.reshape((X.size, ))
elif X.ndim == 2: elif X.ndim == 2:
...@@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): ...@@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
dim = [i for i in range(len(X.shape))] dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim)) X = np.transpose(X, tuple(dim))
if transpose_Y: if transpose_y:
if Y.ndim == 1: if Y.ndim == 1:
Y = Y.reshape((Y.size, )) Y = Y.reshape((Y.size, ))
else: else:
...@@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp( ...@@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp(
class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (1, 1, 12, 4) self.x_shape = (2, 1, 12, 9)
self.y_shape = (1, 2, 4, 12) self.y_shape = (1, 3, 9, 12)
self.trans_x = False self.trans_x = False
self.trans_y = False self.trans_y = False
...@@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2( ...@@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2(
class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
TestMatMulV2VectorXVectorOneDNNOp): TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (2, 2, 5, 4) self.x_shape = (2, 2, 7, 4)
self.y_shape = (2, 2, 5, 3) self.y_shape = (2, 2, 7, 5)
self.trans_x = True self.trans_x = True
self.trans_y = False self.trans_y = False
...@@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( ...@@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp( class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp): TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (3, 1, 6, 5) self.x_shape = (3, 1, 6, 7)
self.y_shape = (1, 2, 6, 9) self.y_shape = (1, 2, 6, 9)
self.trans_x = True self.trans_x = True
self.trans_y = False self.trans_y = False
...@@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): ...@@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (2, 1, 40) self.x_shape = (2, 1, 100)
self.y_shape = (40) self.y_shape = (100)
self.trans_x = False self.trans_x = False
self.trans_y = False self.trans_y = False
...@@ -245,6 +244,8 @@ def create_bf16_test_class(parent): ...@@ -245,6 +244,8 @@ def create_bf16_test_class(parent):
'X': convert_float_to_uint16(x), 'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y) 'Y': convert_float_to_uint16(y)
} }
self.x_fp32 = x
self.y_fp32 = y
def set_dtype_attr(self): def set_dtype_attr(self):
self.attrs['mkldnn_data_type'] = "bfloat16" self.attrs['mkldnn_data_type'] = "bfloat16"
...@@ -253,7 +254,99 @@ def create_bf16_test_class(parent): ...@@ -253,7 +254,99 @@ def create_bf16_test_class(parent):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
def test_check_grad(self): def test_check_grad(self):
pass self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
def matmul_grad(self, x, transpose_x, y, transpose_y):
x = np.transpose(
x, self.shape_transpose_axes[x.ndim]) if transpose_x else x
y = np.transpose(
y, self.shape_transpose_axes[y.ndim]) if transpose_y else y
return np.matmul(x, y)
def calculate_grads(self):
self.shape_transpose_axes = {
2: [1, 0],
3: [0, 2, 1],
4: [0, 1, 3, 2],
5: [0, 1, 2, 4, 3]
}
# expand vector so it will be a valid matrix for multiplication
if self.x_fp32.ndim == 1:
self.x_fp32 = np.expand_dims(self.x_fp32, axis=0)
if self.y_fp32.ndim == 1:
self.y_fp32 = np.expand_dims(self.y_fp32, axis=1)
x_transpose_axes = self.shape_transpose_axes[self.x_fp32.ndim]
y_transpose_axes = self.shape_transpose_axes[self.y_fp32.ndim]
x = np.transpose(self.x_fp32, x_transpose_axes) if self.attrs[
'trans_x'] is True else self.x_fp32
y = np.transpose(self.y_fp32, y_transpose_axes) if self.attrs[
'trans_y'] is True else self.y_fp32
dout = np.matmul(x, y)
x_shape = x.shape
y_shape = y.shape
if x.ndim <= 2 or y.ndim <= 2:
is_broadcast = False
elif x.ndim != y.ndim:
is_broadcast = True
else:
is_broadcast = x.shape[0:-2] != y.shape[0:-2]
if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True:
self.dx = self.matmul_grad(self.y_fp32, True, dout, True)
self.dy = self.matmul_grad(dout, True, self.x_fp32, True)
elif self.attrs['trans_x'] is True and self.attrs[
'trans_y'] is False:
self.dx = self.matmul_grad(self.y_fp32, False, dout, True)
self.dy = self.matmul_grad(self.x_fp32, False, dout, False)
elif self.attrs['trans_x'] is False and self.attrs[
'trans_y'] is True:
self.dx = self.matmul_grad(dout, False, self.y_fp32, False)
self.dy = self.matmul_grad(dout, True, self.x_fp32, False)
else:
self.dx = self.matmul_grad(dout, False, self.y_fp32, True)
self.dy = self.matmul_grad(self.x_fp32, True, dout, False)
if is_broadcast:
x_reduce_axis = []
y_reduce_axis = []
for index, (
first, second
) in enumerate(zip(x_shape[0:-2], self.dx.shape[0:-2])):
if first != second:
x_reduce_axis.append(index)
for index, (
first, second
) in enumerate(zip(y_shape[0:-2], self.dy.shape[0:-2])):
if first != second:
y_reduce_axis.append(index)
if x_reduce_axis:
self.dx = self.dx.sum(axis=tuple(x_reduce_axis),
keepdims=True)
if y_reduce_axis:
self.dy = self.dy.sum(axis=tuple(y_reduce_axis),
keepdims=True)
# after multiplying with vector one dimension is deleted from tensor
if len(x_shape) == 2 and x_shape[0] == 1:
dout = dout.sum(axis=-2)
if len(y_shape) == 2 and y_shape[1] == 1:
dout = dout.sum(axis=-1)
self.dout = dout
cls_name = "{0}_{1}".format(parent.__name__, "BF16") cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestMatMulV2Bf16OneDNNOp.__name__ = cls_name TestMatMulV2Bf16OneDNNOp.__name__ = cls_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册