未验证 提交 0be71571 编写于 作者: J jakpiase 提交者: GitHub

Added matmul_v2 BF16/FP32 BWD kernel (#34192)

* test version of matmul_v2

* added matmul_v2 grad kernel

* minor changes

* minor changes

* minor change for CI approval

* CI fix

* CI fix

* trigger CI

* changes after review, not working yet

* moved ops to anonymous namespaces

* changes after review
上级 44e4d57b
......@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel {
}
std::vector<int64_t> new_dims;
if (ndims_x >= ndims_y) {
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else {
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
......@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
ctx.GetPlace());
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,34 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace platform {
class MKLDNNDeviceContext;
struct CPUPlace;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
using dnnl::memory;
using dnnl::primitive;
using framework::DataLayout;
using framework::ExecutionContext;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::framework::vectorize;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor;
namespace {
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static framework::Tensor FoldOuterDims(const Tensor& input) {
static Tensor FoldOuterDims(const Tensor& input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
......@@ -52,36 +42,38 @@ static framework::Tensor FoldOuterDims(const Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename T>
static framework::Tensor FoldFirstAndLastDims(
const MKLDNNDeviceContext& dev_ctx, const Tensor* input) {
auto input_dims = framework::vectorize(input->dims());
static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
const Tensor* input) {
auto input_dims = vectorize(input->dims());
if (input_dims.size() != 3) {
return *input;
}
framework::Tensor output;
Tensor output;
output.Resize({input_dims[1], input_dims[0], input_dims[2]});
auto output_dims = framework::vectorize(output.dims());
auto output_dims = vectorize(output.dims());
memory::data_type input_type = framework::ToMKLDNNDataType(input->type());
std::string key = platform::CreateKey(dev_ctx, input_dims, input->format(),
input->format(), input_type);
platform::ReorderMKLDNNHandler reorder_handler(output_dims, input->type(),
input_type, dev_ctx,
dev_ctx.GetEngine(), key);
memory::data_type input_type =
paddle::framework::ToMKLDNNDataType(input->type());
std::string key = paddle::platform::CreateKey(
dev_ctx, input_dims, input->format(), input->format(), input_type);
paddle::platform::ReorderMKLDNNHandler reorder_handler(
output_dims, input->type(), input_type, dev_ctx, dev_ctx.GetEngine(),
key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
memory::format_tag::abc, platform::to_void_cast(input->data<T>()));
memory::format_tag::abc,
paddle::platform::to_void_cast(input->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
&output, memory::format_tag::bac, dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
paddle::platform::RecordEvent record_reorder(
"int_reorder", paddle::platform::EventRole::kUniqueOp);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto& astream = MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
......@@ -90,19 +82,23 @@ static framework::Tensor FoldFirstAndLastDims(
}
template <typename T>
class MatMulMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
Tensor* x, bool trans_x, Tensor* y, bool trans_y,
Tensor* out, float scale, const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::matmul>(
const mkldnn::engine engine,
paddle::platform::Place cpu_place, Tensor* x,
bool trans_x, Tensor* y, bool trans_y, Tensor* out,
float scale, const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
auto mat_dim_x = math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y = math::CreateMatrixDescriptor(y->dims(), 0, trans_y);
auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
x->dims(), 0, trans_x);
auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
y->dims(), 0, trans_y);
memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;
......@@ -149,20 +145,21 @@ constexpr bool IsInt8() {
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, platform::bfloat16>::value;
return std::is_same<T, paddle::platform::bfloat16>::value;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : framework::make_ddim({1, x_dim[0]});
static paddle::framework::DDim RowMatrixDimsFromVector(
const paddle::framework::DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
}
// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
// original y_dim is returned.
static framework::DDim ColumnMatrixDimsFromVector(
const framework::DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : framework::make_ddim({y_dim[0], 1});
static paddle::framework::DDim ColumnMatrixDimsFromVector(
const paddle::framework::DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
}
/**
......@@ -172,7 +169,7 @@ static framework::DDim ColumnMatrixDimsFromVector(
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorToMatrixSequence(
framework::Tensor* x, const math::MatDescriptor& descriptor) {
Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
......@@ -200,14 +197,14 @@ static void ReshapeTensorToMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutToMatrixSequence(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_x,
bool trans_y) {
static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
bool trans_x, bool trans_y) {
auto x_dim = RowMatrixDimsFromVector(x->dims());
auto y_dim = ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y);
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
......@@ -244,8 +241,7 @@ class MatMulFactory {
};
void SetDNNLEngine(const ExecutionContext& ctx) {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
engine_ = dev_ctx.GetEngine();
}
......@@ -263,19 +259,19 @@ class MatMulFactory {
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
platform::errors::InvalidArgument(
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank, axis_size,
platform::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
PADDLE_ENFORCE_EQ(in_rank, axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
platform::errors::InvalidArgument(
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
......@@ -285,8 +281,8 @@ class MatMulFactory {
return new_x;
}
std::pair<math::MatDescriptor, memory::dims> GetInputDimsAndStrides(
const ExecutionContext& ctx, std::string input_name) {
std::pair<paddle::operators::math::MatDescriptor, memory::dims>
GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
......@@ -297,9 +293,10 @@ class MatMulFactory {
auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector
: ColumnMatrixDimsFromVector;
math::MatDescriptor mat_dim =
math::CreateMatrixDescriptor(MatrixDimsFromVector(new_dims), 0,
ctx.Attr<bool>("transpose_" + input_name));
paddle::operators::math::MatDescriptor mat_dim =
paddle::operators::math::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0,
ctx.Attr<bool>("transpose_" + input_name));
memory::dims strides;
if (!shape.empty()) {
......@@ -340,17 +337,17 @@ class MatMulFactory {
}
MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
math::MatDescriptor mat_dim_x;
paddle::operators::math::MatDescriptor mat_dim_x;
memory::dims strides_x;
std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
math::MatDescriptor mat_dim_y;
paddle::operators::math::MatDescriptor mat_dim_y;
memory::dims strides_y;
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
auto x_bs = mat_dim_x.batch_size_;
auto y_bs = mat_dim_y.batch_size_;
PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
platform::errors::InvalidArgument(
paddle::platform::errors::InvalidArgument(
"If batch sizes of X and Y are positive,"
"they have to be equal."));
......@@ -448,10 +445,10 @@ class MatMulFactory {
}
void SetOutputFormat(const ExecutionContext& ctx) {
using platform::MKLDNNFormatForSize;
using paddle::platform::MKLDNNFormatForSize;
auto* out = ctx.Output<Tensor>("Out");
auto format =
MKLDNNFormatForSize(out->dims().size(), MKLDNNMemoryFormat::nchw);
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
}
......@@ -495,8 +492,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
std::string key = platform::CreateKey(dev_ctx, batch_size, out_name);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name);
key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......@@ -529,161 +526,170 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
}
template <typename T>
class DNNLMatMulKernel : public framework::OpKernel<T> {
class DNNLMatMulKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
platform::errors::Unimplemented(
paddle::platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
platform::MKLDNNDeviceContext::tls().log_lib_version();
MKLDNNDeviceContext::tls().log_lib_version();
ExecuteMatMul<T, T>(ctx);
}
};
} // anonymous namespace
namespace paddle {
namespace operators {
template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
RunKernel<T>(ctx);
void MatMulGradMKLDNNKernel<T>::Compute(const ExecutionContext& ctx) const {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
RunKernel(ctx);
}
private:
void ExecuteMatMulGrad(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Tensor x_combined, y_combined;
if (!need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y);
}
template <typename T>
void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
Tensor* out, int execution_number) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Tensor x_combined, y_combined;
if (!need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y);
}
MatMulMKLDNNHandler<T> handler(
dev_ctx, engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined,
trans_y, out, ctx.Attr<float>("alpha"),
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
const auto dst_memory_p = handler.AcquireDstMemory(out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape(
framework::vectorize<int64_t>(out->dims()))));
}
template <typename Tout = T>
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto x = *ctx.Input<Tensor>("X");
auto y = *ctx.Input<Tensor>("Y");
auto dout = *ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
bool transpose_x = ctx.Attr<bool>("transpose_X");
bool transpose_y = ctx.Attr<bool>("transpose_Y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
trans_x, &y_combined, trans_y, out, alpha,
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
const auto dst_memory_p = handler.AcquireDstMemory(out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
}
template <typename T>
void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto x = *ctx.Input<Tensor>("X");
auto y = *ctx.Input<Tensor>("Y");
auto dout = *ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr<bool>("transpose_X")
: ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y") ? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true,
&dout, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true,
&x, true, false, dy, 1);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
&dout, false, true, dy, 1);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true,
&x, false, true, dy, 1);
} else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true,
&dout, false, true, dy, 1);
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
true, false, dy, 1);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
&dout, false, true, dy, 1);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
false, true, dy, 1);
} else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
false, true, dy, 1);
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_format(x.format());
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_format(y.format());
}
}
};
}
template class MatMulGradMKLDNNKernel<float>;
template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
ops::DNNLMatMulKernel<float>,
ops::DNNLMatMulKernel<paddle::platform::bfloat16>,
ops::DNNLMatMulKernel<int8_t>,
ops::DNNLMatMulKernel<uint8_t>);
DNNLMatMulKernel<float>,
DNNLMatMulKernel<paddle::platform::bfloat16>,
DNNLMatMulKernel<int8_t>, DNNLMatMulKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulGradMKLDNNKernel<float>,
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using platform::MKLDNNDeviceContext;
using framework::ExecutionContext;
using Tensor = framework::Tensor;
template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override;
private:
void ExecuteMatMulGrad(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const;
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
} // namespace paddle
......@@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
namespace paddle {
namespace operators {
namespace {
using dnnl::memory;
using dnnl::primitive;
using framework::DataLayout;
using framework::ExecutionContext;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor;
using paddle::framework::vectorize;
using paddle::framework::make_ddim;
using paddle::framework::GradVarName;
template <typename T>
class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
std::vector<int64_t>& x_dims, bool trans_x,
std::vector<int64_t>& y_dims, bool trans_y,
const mkldnn::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::matmul>(
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, x_dims, uniq_name)) {
paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
......@@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
};
template <typename T>
class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
class MatMulV2MKLDNNKernel
: public paddle::operators::MatMulGradMKLDNNKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
protected:
void ExecuteMatMul(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
paddle::platform::Place cpu_place, const Tensor* x,
std::vector<int64_t>& x_dims, bool trans_x,
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(
dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
trans_y, ctx.InputName("X") + std::to_string(execution_number));
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
}
private:
void CalculateMatrixDims(const ExecutionContext& ctx,
const std::vector<int64_t>& x_dims,
......@@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t>& out_dims, Tensor* out) const {
if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
x_bd_dims[2] = x_dims[1];
x_bd_dims[1] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i];
......@@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
}
if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
y_bd_dims[2] = y_dims[1];
y_bd_dims[1] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i];
......@@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < x_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
platform::errors::InvalidArgument(
paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_dims[i], i, y_dims[i]));
out_dims[i] = std::max(x_dims[i], y_dims[i]);
}
out->Resize(framework::make_ddim(out_dims));
out->Resize(make_ddim(out_dims));
}
}
......@@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto x_dims = framework::vectorize(x->dims());
auto y_dims = framework::vectorize(y->dims());
auto out_dims = framework::vectorize(out->dims());
auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(y->dims());
auto out_dims = vectorize(out->dims());
int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3);
......@@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
out);
MatMulV2MKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(),
x_bd_dims, trans_x, y_bd_dims, trans_y,
ctx.InputName("X"));
ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
trans_x, y, y_bd_dims, trans_y, out, out_dims);
}
};
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);
template <typename T>
class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
auto matmul_p = handler.AcquireForwardPrimitive();
private:
void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp,
Tensor* dy_tmp,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& dy_dims,
std::vector<int64_t>& dx_bd_dims,
std::vector<int64_t>& dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
dx_bd_dims[i] = dy_dims[i];
} else {
dy_bd_dims[i] = dx_dims[i];
}
}
}
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
dx_tmp->Resize(make_ddim(dx_bd_dims));
dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(make_ddim(dy_bd_dims));
dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
}
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(y->dims());
bool is_broadcast = true;
if (x_dims.size() <= 2 || y_dims.size() <= 2) {
is_broadcast = false;
} else if (x_dims.size() != y_dims.size()) {
is_broadcast = true;
} else {
is_broadcast =
!std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
}
// if no broadcasting is needed, we can simply use matmul's grad and avoid
// using reduce_sum
if (!is_broadcast) {
paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
return;
}
auto* dout = ctx.Input<Tensor>(GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(GradVarName("X"));
auto* dy = ctx.Output<Tensor>(GradVarName("Y"));
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto dout_dims = vectorize(dout->dims());
int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3);
// in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims
Tensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims,
dy_bd_dims);
if (trans_x && trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, true, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
2);
} else if (trans_x) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, false, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, false, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else if (trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, false, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, true, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
}
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims);
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims);
} else {
*dy = std::move(dy_tmp);
}
dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_format(x->format());
dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
}
};
} // namespace operators
} // namespace paddle
} // anonymous namespace
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulV2MKLDNNKernel<float>,
ops::MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
// REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace,
// ops::MatMulV2GradMKLDNNKernel<float>,
// ops::MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace,
MatMulV2GradMKLDNNKernel<float>,
MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
......@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
from functools import reduce
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
......@@ -23,14 +24,12 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
paddle.enable_static()
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
"""Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately.
if transpose_X:
if transpose_x:
if X.ndim == 1:
X = X.reshape((X.size, ))
elif X.ndim == 2:
......@@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim))
if transpose_Y:
if transpose_y:
if Y.ndim == 1:
Y = Y.reshape((Y.size, ))
else:
......@@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp(
class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (1, 1, 12, 4)
self.y_shape = (1, 2, 4, 12)
self.x_shape = (2, 1, 12, 9)
self.y_shape = (1, 3, 9, 12)
self.trans_x = False
self.trans_y = False
......@@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2(
class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (2, 2, 5, 4)
self.y_shape = (2, 2, 5, 3)
self.x_shape = (2, 2, 7, 4)
self.y_shape = (2, 2, 7, 5)
self.trans_x = True
self.trans_y = False
......@@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (3, 1, 6, 5)
self.x_shape = (3, 1, 6, 7)
self.y_shape = (1, 2, 6, 9)
self.trans_x = True
self.trans_y = False
......@@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (2, 1, 40)
self.y_shape = (40)
self.x_shape = (2, 1, 100)
self.y_shape = (100)
self.trans_x = False
self.trans_y = False
......@@ -245,6 +244,8 @@ def create_bf16_test_class(parent):
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y)
}
self.x_fp32 = x
self.y_fp32 = y
def set_dtype_attr(self):
self.attrs['mkldnn_data_type'] = "bfloat16"
......@@ -253,7 +254,99 @@ def create_bf16_test_class(parent):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
def matmul_grad(self, x, transpose_x, y, transpose_y):
x = np.transpose(
x, self.shape_transpose_axes[x.ndim]) if transpose_x else x
y = np.transpose(
y, self.shape_transpose_axes[y.ndim]) if transpose_y else y
return np.matmul(x, y)
def calculate_grads(self):
self.shape_transpose_axes = {
2: [1, 0],
3: [0, 2, 1],
4: [0, 1, 3, 2],
5: [0, 1, 2, 4, 3]
}
# expand vector so it will be a valid matrix for multiplication
if self.x_fp32.ndim == 1:
self.x_fp32 = np.expand_dims(self.x_fp32, axis=0)
if self.y_fp32.ndim == 1:
self.y_fp32 = np.expand_dims(self.y_fp32, axis=1)
x_transpose_axes = self.shape_transpose_axes[self.x_fp32.ndim]
y_transpose_axes = self.shape_transpose_axes[self.y_fp32.ndim]
x = np.transpose(self.x_fp32, x_transpose_axes) if self.attrs[
'trans_x'] is True else self.x_fp32
y = np.transpose(self.y_fp32, y_transpose_axes) if self.attrs[
'trans_y'] is True else self.y_fp32
dout = np.matmul(x, y)
x_shape = x.shape
y_shape = y.shape
if x.ndim <= 2 or y.ndim <= 2:
is_broadcast = False
elif x.ndim != y.ndim:
is_broadcast = True
else:
is_broadcast = x.shape[0:-2] != y.shape[0:-2]
if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True:
self.dx = self.matmul_grad(self.y_fp32, True, dout, True)
self.dy = self.matmul_grad(dout, True, self.x_fp32, True)
elif self.attrs['trans_x'] is True and self.attrs[
'trans_y'] is False:
self.dx = self.matmul_grad(self.y_fp32, False, dout, True)
self.dy = self.matmul_grad(self.x_fp32, False, dout, False)
elif self.attrs['trans_x'] is False and self.attrs[
'trans_y'] is True:
self.dx = self.matmul_grad(dout, False, self.y_fp32, False)
self.dy = self.matmul_grad(dout, True, self.x_fp32, False)
else:
self.dx = self.matmul_grad(dout, False, self.y_fp32, True)
self.dy = self.matmul_grad(self.x_fp32, True, dout, False)
if is_broadcast:
x_reduce_axis = []
y_reduce_axis = []
for index, (
first, second
) in enumerate(zip(x_shape[0:-2], self.dx.shape[0:-2])):
if first != second:
x_reduce_axis.append(index)
for index, (
first, second
) in enumerate(zip(y_shape[0:-2], self.dy.shape[0:-2])):
if first != second:
y_reduce_axis.append(index)
if x_reduce_axis:
self.dx = self.dx.sum(axis=tuple(x_reduce_axis),
keepdims=True)
if y_reduce_axis:
self.dy = self.dy.sum(axis=tuple(y_reduce_axis),
keepdims=True)
# after multiplying with vector one dimension is deleted from tensor
if len(x_shape) == 2 and x_shape[0] == 1:
dout = dout.sum(axis=-2)
if len(y_shape) == 2 and y_shape[1] == 1:
dout = dout.sum(axis=-1)
self.dout = dout
cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestMatMulV2Bf16OneDNNOp.__name__ = cls_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册