未验证 提交 0be71571 编写于 作者: J jakpiase 提交者: GitHub

Added matmul_v2 BF16/FP32 BWD kernel (#34192)

* test version of matmul_v2

* added matmul_v2 grad kernel

* minor changes

* minor changes

* minor change for CI approval

* CI fix

* CI fix

* trigger CI

* changes after review, not working yet

* moved ops to anonymous namespaces

* changes after review
上级 44e4d57b
...@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel {
} }
std::vector<int64_t> new_dims; std::vector<int64_t> new_dims;
if (ndims_x >= ndims_y) { if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2); new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else { } else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2); new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
} }
if (!x_broadcasted) { if (!x_broadcasted) {
new_dims.push_back(M); new_dims.push_back(M);
...@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { ...@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto out_grad_name = framework::GradVarName("Out"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(
return framework::OpKernelType( ctx, framework::GradVarName("Out"));
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
ctx.GetPlace()); #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using platform::MKLDNNDeviceContext;
using framework::ExecutionContext;
using Tensor = framework::Tensor;
// Declaration of the MKLDNN (oneDNN) gradient kernel for the matmul operator.
//
// Only the interface is declared in this header; the member functions are
// defined out of line. Elsewhere in this change MatMulV2MKLDNNKernel derives
// from this class, and the matmul_v2 grad kernel forwards to Compute() when
// no batch-dimension broadcasting is involved.
//
// T is the element type of the tensors handled by the kernel.
template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
public:
// Kernel entry point, overriding framework::OpKernel<T>::Compute.
void Compute(const ExecutionContext& ctx) const override;
private:
// Helper declared for running one matmul of the backward pass into `out`.
// NOTE(review): the is_fold_init_dims_* flags presumably request folding of
// the leading dimensions of x / y before the multiplication, and
// execution_number presumably tells the individual matmuls of one backward
// pass apart -- confirm both against the definition.
void ExecuteMatMulGrad(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const;
// Internal driver of the kernel; its body is not visible in this header.
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
} // namespace paddle
...@@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace {
namespace operators {
using dnnl::memory; using dnnl::memory;
using dnnl::primitive; using dnnl::primitive;
using framework::DataLayout; using paddle::framework::DataLayout;
using framework::ExecutionContext; using paddle::framework::ExecutionContext;
using platform::GetMKLDNNFormat; using paddle::platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType; using paddle::platform::MKLDNNGetDataType;
using platform::to_void_cast; using paddle::platform::to_void_cast;
using Tensor = framework::Tensor; using Tensor = paddle::framework::Tensor;
using paddle::framework::vectorize;
using paddle::framework::make_ddim;
using paddle::framework::GradVarName;
template <typename T> template <typename T>
class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> { class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
public: public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place, const mkldnn::engine engine,
std::vector<int64_t>& x_dims, bool trans_x, paddle::platform::Place cpu_place,
std::vector<int64_t>& y_dims, bool trans_y, const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name) const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::matmul>( : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place, dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, x_dims, uniq_name)) { paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) { if (!this->isCached()) {
// M X K * K X N // M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3; const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2; const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1; const int W_idx = x_dims.size() - 1;
...@@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> { ...@@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
}; };
template <typename T> template <typename T>
class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { class MatMulV2MKLDNNKernel
: public paddle::operators::MatMulGradMKLDNNKernel<T> {
public: public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
protected:
// Runs a single oneDNN matmul with `x` as the source, `y` as the weights and
// `out` as the destination, then tags `out` with the MKLDNN layout/format.
//
// ctx              execution context; supplies the name of input "X".
// dev_ctx          MKLDNN device context handed to the handler.
// onednn_engine    engine the primitive is created on.
// cpu_place        place handed to the handler.
// x, y             operand tensors.
// x_dims, y_dims   dims the handler is built with (the forward kernel passes
//                  its broadcast-padded dims here).
// trans_x, trans_y transposition flags forwarded to the handler.
// out              destination tensor.
// out_dims         dims used to reshape the destination descriptor when
//                  deriving the output format.
// execution_number suffix appended to the handler's unique name; the grad
//                  kernel passes 1 and 2 for its two matmuls.
void ExecuteMatMul(const ExecutionContext& ctx,
                   const MKLDNNDeviceContext& dev_ctx,
                   const mkldnn::engine onednn_engine,
                   paddle::platform::Place cpu_place, const Tensor* x,
                   std::vector<int64_t>& x_dims, bool trans_x,
                   const Tensor* y, std::vector<int64_t>& y_dims,
                   bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
                   int execution_number = 0) const {
  // The handler derives its cache key from this name, so the numeric suffix
  // keeps the handlers of different calls within one kernel run apart.
  const std::string uniq_name =
      ctx.InputName("X") + std::to_string(execution_number);

  // Use the place supplied by the caller; previously the parameter was
  // ignored and ctx.GetPlace() was queried again.
  MatMulV2MKLDNNHandler<T> handler(dev_ctx, onednn_engine, cpu_place, x_dims,
                                   trans_x, y_dims, trans_y, uniq_name);

  const auto src_memory_p = handler.AcquireSrcMemory(x);
  const auto weights_memory_p = handler.AcquireWeightsMemory(y);
  const auto dst_memory_p = handler.AcquireDstMemory(out);

  auto matmul_p = handler.AcquireForwardPrimitive();

  std::unordered_map<int, memory> matmul_args = {
      {DNNL_ARG_SRC, *src_memory_p},
      {DNNL_ARG_WEIGHTS, *weights_memory_p},
      {DNNL_ARG_DST, *dst_memory_p}};

  // Execute on the thread-local stream and wait for completion before the
  // output tensor's metadata is updated.
  auto& astream = MKLDNNDeviceContext::tls().get_stream();
  matmul_p->execute(astream, matmul_args);
  astream.wait();

  out->set_layout(paddle::framework::DataLayout::kMKLDNN);
  out->set_format(
      GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
}
private: private:
void CalculateMatrixDims(const ExecutionContext& ctx, void CalculateMatrixDims(const ExecutionContext& ctx,
const std::vector<int64_t>& x_dims, const std::vector<int64_t>& x_dims,
...@@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
std::vector<int64_t>& out_dims, Tensor* out) const { std::vector<int64_t>& out_dims, Tensor* out) const {
if (x_dims.size() == 1) { if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
x_bd_dims[2] = x_dims[1];
x_bd_dims[1] = x_dims[0];
} else { } else {
for (size_t i = 0; i < x_dims.size(); ++i) { for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i]; x_bd_dims[i] = x_dims[i];
...@@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
} }
if (y_dims.size() == 1) { if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
y_bd_dims[2] = y_dims[1];
y_bd_dims[1] = y_dims[0];
} else { } else {
for (size_t i = 0; i < y_dims.size(); ++i) { for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i]; y_bd_dims[i] = y_dims[i];
...@@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < x_dims.size() - 2; ++i) { for (size_t i = 0; i < x_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting." "Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but " "Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d", "received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_dims[i], i, y_dims[i])); i, x_dims[i], i, y_dims[i]));
out_dims[i] = std::max(x_dims[i], y_dims[i]); out_dims[i] = std::max(x_dims[i], y_dims[i]);
} }
out->Resize(framework::make_ddim(out_dims)); out->Resize(make_ddim(out_dims));
} }
} }
...@@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x"); bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
auto x_dims = framework::vectorize(x->dims()); auto x_dims = vectorize(x->dims());
auto y_dims = framework::vectorize(y->dims()); auto y_dims = vectorize(y->dims());
auto out_dims = framework::vectorize(out->dims()); auto out_dims = vectorize(out->dims());
int ndims = std::max(x->dims().size(), y->dims().size()); int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3); ndims = std::max(ndims, 3);
...@@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> { ...@@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims, CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
out); out);
MatMulV2MKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
x_bd_dims, trans_x, y_bd_dims, trans_y, trans_x, y, y_bd_dims, trans_y, out, out_dims);
ctx.InputName("X")); }
};
const auto src_memory_p = handler.AcquireSrcMemory(x); template <typename T>
const auto weights_memory_p = handler.AcquireWeightsMemory(y); class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
const auto dst_memory_p = handler.AcquireDstMemory(out); public:
void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
auto matmul_p = handler.AcquireForwardPrimitive(); private:
void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp,
Tensor* dy_tmp,
const std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& dy_dims,
std::vector<int64_t>& dx_bd_dims,
std::vector<int64_t>& dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
dx_bd_dims[i] = dy_dims[i];
} else {
dy_bd_dims[i] = dx_dims[i];
}
}
}
std::unordered_map<int, memory> matmul_args = { dx_tmp->Resize(make_ddim(dx_bd_dims));
{DNNL_ARG_SRC, *src_memory_p}, dx_tmp->mutable_data<T>(ctx.GetPlace());
{DNNL_ARG_WEIGHTS, *weights_memory_p}, dy_tmp->Resize(make_ddim(dy_bd_dims));
{DNNL_ARG_DST, *dst_memory_p}}; dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = MKLDNNDeviceContext::tls().get_stream(); auto& astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait(); astream.wait();
}
out->set_layout(framework::DataLayout::kMKLDNN); void RunKernel(const ExecutionContext& ctx) const {
out->set_format( const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims))); const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(y->dims());
bool is_broadcast = true;
if (x_dims.size() <= 2 || y_dims.size() <= 2) {
is_broadcast = false;
} else if (x_dims.size() != y_dims.size()) {
is_broadcast = true;
} else {
is_broadcast =
!std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
}
// if no broadcasting is needed, we can simply use matmul's grad and avoid
// using reduce_sum
if (!is_broadcast) {
paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
return;
}
auto* dout = ctx.Input<Tensor>(GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(GradVarName("X"));
auto* dy = ctx.Output<Tensor>(GradVarName("Y"));
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto dout_dims = vectorize(dout->dims());
int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3);
// in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims
Tensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims,
dy_bd_dims);
if (trans_x && trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, true, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims,
2);
} else if (trans_x) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
y_dims, false, dout, dout_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, false, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else if (trans_y) {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, false, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, true, x, x_dims, false, &dy_tmp,
dy_bd_dims, 2);
} else {
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
dout_dims, false, y, y_dims, true, &dx_tmp,
dx_bd_dims, 1);
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
x_dims, true, dout, dout_dims, false, &dy_tmp,
dy_bd_dims, 2);
}
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims);
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims);
} else {
*dy = std::move(dy_tmp);
}
dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_format(x->format());
dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
} }
}; };
} // namespace operators } // anonymous namespace
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulV2MKLDNNKernel<float>, MatMulV2MKLDNNKernel<float>,
ops::MatMulV2MKLDNNKernel<paddle::platform::bfloat16>); MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
// REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace,
// ops::MatMulV2GradMKLDNNKernel<float>, MatMulV2GradMKLDNNKernel<float>,
// ops::MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>); MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from functools import reduce
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
...@@ -23,14 +24,12 @@ import paddle ...@@ -23,14 +24,12 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
paddle.enable_static()
def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
"""Reference forward implementation using np.matmul.""" """Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually # np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately. # transpose X and Y appropriately.
if transpose_X: if transpose_x:
if X.ndim == 1: if X.ndim == 1:
X = X.reshape((X.size, )) X = X.reshape((X.size, ))
elif X.ndim == 2: elif X.ndim == 2:
...@@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): ...@@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
dim = [i for i in range(len(X.shape))] dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim)) X = np.transpose(X, tuple(dim))
if transpose_Y: if transpose_y:
if Y.ndim == 1: if Y.ndim == 1:
Y = Y.reshape((Y.size, )) Y = Y.reshape((Y.size, ))
else: else:
...@@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp( ...@@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp(
class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (1, 1, 12, 4) self.x_shape = (2, 1, 12, 9)
self.y_shape = (1, 2, 4, 12) self.y_shape = (1, 3, 9, 12)
self.trans_x = False self.trans_x = False
self.trans_y = False self.trans_y = False
...@@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2( ...@@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2(
class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
TestMatMulV2VectorXVectorOneDNNOp): TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (2, 2, 5, 4) self.x_shape = (2, 2, 7, 4)
self.y_shape = (2, 2, 5, 3) self.y_shape = (2, 2, 7, 5)
self.trans_x = True self.trans_x = True
self.trans_y = False self.trans_y = False
...@@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3( ...@@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp( class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp): TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (3, 1, 6, 5) self.x_shape = (3, 1, 6, 7)
self.y_shape = (1, 2, 6, 9) self.y_shape = (1, 2, 6, 9)
self.trans_x = True self.trans_x = True
self.trans_y = False self.trans_y = False
...@@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): ...@@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self): def config(self):
self.x_shape = (2, 1, 40) self.x_shape = (2, 1, 100)
self.y_shape = (40) self.y_shape = (100)
self.trans_x = False self.trans_x = False
self.trans_y = False self.trans_y = False
...@@ -245,6 +244,8 @@ def create_bf16_test_class(parent): ...@@ -245,6 +244,8 @@ def create_bf16_test_class(parent):
'X': convert_float_to_uint16(x), 'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y) 'Y': convert_float_to_uint16(y)
} }
self.x_fp32 = x
self.y_fp32 = y
def set_dtype_attr(self): def set_dtype_attr(self):
self.attrs['mkldnn_data_type'] = "bfloat16" self.attrs['mkldnn_data_type'] = "bfloat16"
...@@ -253,7 +254,99 @@ def create_bf16_test_class(parent): ...@@ -253,7 +254,99 @@ def create_bf16_test_class(parent):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
def test_check_grad(self): def test_check_grad(self):
pass self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
def matmul_grad(self, x, transpose_x, y, transpose_y):
    """Return np.matmul(x, y) after optionally transposing each operand.

    Args:
        x: left operand, an ndarray of rank >= 2 when transpose_x is set.
        transpose_x: if truthy, swap the two innermost axes of x first.
        y: right operand, an ndarray of rank >= 2 when transpose_y is set.
        transpose_y: if truthy, swap the two innermost axes of y first.

    Returns:
        The (batched) matrix product of the possibly transposed operands.
    """
    # Swapping the last two axes is what the former per-rank axes table
    # encoded for ranks 2..5; np.swapaxes does it for any rank >= 2 and
    # does not depend on state prepared by another method.
    if transpose_x:
        x = np.swapaxes(x, -1, -2)
    if transpose_y:
        y = np.swapaxes(y, -1, -2)
    return np.matmul(x, y)
def calculate_grads(self):
    """Compute NumPy reference values and store them as self.dx, self.dy, self.dout.

    Reads self.x_fp32, self.y_fp32 and the 'trans_x' / 'trans_y' entries of
    self.attrs. 1-D inputs are expanded in place on self to 2-D matrices.
    """
    # Axis permutations swapping the two innermost axes, keyed by rank.
    # Kept on self because matmul_grad looks them up.
    self.shape_transpose_axes = {
        2: [1, 0],
        3: [0, 2, 1],
        4: [0, 1, 3, 2],
        5: [0, 1, 2, 4, 3]
    }

    # expand vector so it will be a valid matrix for multiplication
    if self.x_fp32.ndim == 1:
        self.x_fp32 = np.expand_dims(self.x_fp32, axis=0)
    if self.y_fp32.ndim == 1:
        self.y_fp32 = np.expand_dims(self.y_fp32, axis=1)

    x_axes = self.shape_transpose_axes[self.x_fp32.ndim]
    y_axes = self.shape_transpose_axes[self.y_fp32.ndim]

    trans_x = self.attrs['trans_x']
    trans_y = self.attrs['trans_y']

    x = self.x_fp32
    if trans_x is True:
        x = np.transpose(self.x_fp32, x_axes)
    y = self.y_fp32
    if trans_y is True:
        y = np.transpose(self.y_fp32, y_axes)

    # The forward result doubles as the upstream gradient.
    dout = np.matmul(x, y)

    x_shape, y_shape = x.shape, y.shape

    # Broadcasting only matters when both operands carry batch dimensions
    # and those differ in count or in extent.
    is_broadcast = (x.ndim > 2 and y.ndim > 2 and
                    (x.ndim != y.ndim or x_shape[0:-2] != y_shape[0:-2]))

    # Pick the operands of the two gradient products for each combination
    # of transposition flags.
    if trans_x is True and trans_y is True:
        dx_operands = (self.y_fp32, True, dout, True)
        dy_operands = (dout, True, self.x_fp32, True)
    elif trans_x is True and trans_y is False:
        dx_operands = (self.y_fp32, False, dout, True)
        dy_operands = (self.x_fp32, False, dout, False)
    elif trans_x is False and trans_y is True:
        dx_operands = (dout, False, self.y_fp32, False)
        dy_operands = (dout, True, self.x_fp32, False)
    else:
        dx_operands = (dout, False, self.y_fp32, True)
        dy_operands = (self.x_fp32, True, dout, False)

    self.dx = self.matmul_grad(*dx_operands)
    self.dy = self.matmul_grad(*dy_operands)

    if is_broadcast:

        def differing_batch_axes(operand_shape, grad_shape):
            # Leading axes on which the operand and its raw gradient disagree.
            pairs = zip(operand_shape[0:-2], grad_shape[0:-2])
            return tuple(axis for axis, (first, second) in enumerate(pairs)
                         if first != second)

        x_reduce_axis = differing_batch_axes(x_shape, self.dx.shape)
        y_reduce_axis = differing_batch_axes(y_shape, self.dy.shape)

        # Sum the gradients over the axes that were broadcast.
        if x_reduce_axis:
            self.dx = self.dx.sum(axis=x_reduce_axis, keepdims=True)
        if y_reduce_axis:
            self.dy = self.dy.sum(axis=y_reduce_axis, keepdims=True)

    # after multiplying with vector one dimension is deleted from tensor
    # NOTE(review): kept outside the broadcast branch, because 2-D operands
    # never take that branch -- indentation was lost in this view, confirm.
    if len(x_shape) == 2 and x_shape[0] == 1:
        dout = dout.sum(axis=-2)
    if len(y_shape) == 2 and y_shape[1] == 1:
        dout = dout.sum(axis=-1)

    self.dout = dout
cls_name = "{0}_{1}".format(parent.__name__, "BF16") cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestMatMulV2Bf16OneDNNOp.__name__ = cls_name TestMatMulV2Bf16OneDNNOp.__name__ = cls_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册