Unverified — commit 0be71571 authored by jakpiase, committed by GitHub

Added matmul_v2 BF16/FP32 BWD kernel (#34192)

* test version of matmul_v2

* added matmul_v2 grad kernel

* minor changes

* minor changes

* minor change for CI approval

* CI fix

* CI fix

* trigger CI

* changes after review, not working yet

* moved ops to anonymous namespaces

* changes after review
Parent 44e4d57b
@@ -62,10 +62,15 @@ class MatMulV2Op : public framework::OperatorWithKernel {
     }
     std::vector<int64_t> new_dims;
-    if (ndims_x >= ndims_y) {
+    if (ndims_x > ndims_y) {
       new_dims.assign(dims_x.begin(), dims_x.end() - 2);
-    } else {
+    } else if (ndims_x < ndims_y) {
       new_dims.assign(dims_y.begin(), dims_y.end() - 2);
+    } else {
+      new_dims.reserve(ndims_x);
+      for (size_t i = 0; i < ndims_x - 2; ++i) {
+        new_dims.push_back(std::max(dims_x[i], dims_y[i]));
+      }
     }
     if (!x_broadcasted) {
       new_dims.push_back(M);
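The new `else` branch handles equal-rank inputs by broadcasting the batch dimensions elementwise. A minimal numpy sketch of the same shape rule (the helper name `broadcast_batch_dims` is ours, not Paddle's):

```python
import numpy as np

def broadcast_batch_dims(dims_x, dims_y):
    # Same rule as the new else branch above: for equal-rank inputs, each
    # batch (leading) dim of the output is max(dims_x[i], dims_y[i]).
    assert len(dims_x) == len(dims_y)
    return [max(a, b) for a, b in zip(dims_x[:-2], dims_y[:-2])]

# (2, 1, 12, 9) x (1, 3, 9, 12) -> batch dims [2, 3], so Out is (2, 3, 12, 12)
print(broadcast_batch_dims((2, 1, 12, 9), (1, 3, 9, 12)))  # [2, 3]
```

For valid inputs `np.broadcast_shapes((2, 1), (1, 3))` gives the same answer, since this is ordinary numpy-style broadcasting restricted to the batch dims.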
@@ -169,10 +174,17 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto out_grad_name = framework::GradVarName("Out");
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
-        ctx.GetPlace());
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
......
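With this change the grad op mirrors the forward op's kernel selection: the data type is taken from the `Out` gradient, and when the build has oneDNN and the op can use it for that type, the oneDNN kernel is chosen over the plain CPU one. A rough Python sketch of that decision (names are illustrative, not Paddle API):

```python
def expected_kernel_type(dtype, use_mkldnn, built_with_mkldnn=True):
    # Roughly what GetExpectedKernelType above does: the dtype comes from
    # GradVarName("Out"); oneDNN is preferred whenever it is usable.
    if built_with_mkldnn and use_mkldnn and dtype in ("float32", "bfloat16"):
        return {"place": "CPU", "layout": "kMKLDNN", "library": "kMKLDNN",
                "dtype": dtype}
    return {"place": "CPU", "layout": "default", "library": "plain",
            "dtype": dtype}

print(expected_kernel_type("bfloat16", use_mkldnn=True))
```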
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using platform::MKLDNNDeviceContext;
using framework::ExecutionContext;
using Tensor = framework::Tensor;

template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const ExecutionContext& ctx) const override;

 private:
  void ExecuteMatMulGrad(const ExecutionContext& ctx,
                         const MKLDNNDeviceContext& dev_ctx,
                         const mkldnn::engine& engine, Tensor* x, bool trans_x,
                         bool is_fold_init_dims_x, Tensor* y, bool trans_y,
                         bool is_fold_init_dims_y, Tensor* out,
                         int execution_number) const;

  void RunKernel(const ExecutionContext& ctx) const;
};

}  // namespace operators
}  // namespace paddle
@@ -12,37 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
 
-namespace paddle {
-namespace operators {
+namespace {
 
 using dnnl::memory;
 using dnnl::primitive;
-using framework::DataLayout;
-using framework::ExecutionContext;
-using platform::GetMKLDNNFormat;
-using platform::MKLDNNDeviceContext;
-using platform::MKLDNNGetDataType;
-using platform::to_void_cast;
-using Tensor = framework::Tensor;
+using paddle::framework::DataLayout;
+using paddle::framework::ExecutionContext;
+using paddle::platform::GetMKLDNNFormat;
+using paddle::platform::MKLDNNDeviceContext;
+using paddle::platform::MKLDNNGetDataType;
+using paddle::platform::to_void_cast;
+using Tensor = paddle::framework::Tensor;
+using paddle::framework::vectorize;
+using paddle::framework::make_ddim;
+using paddle::framework::GradVarName;
 
 template <typename T>
-class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
+class MatMulV2MKLDNNHandler
+    : public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
  public:
   MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
-                        const mkldnn::engine engine, platform::Place cpu_place,
-                        std::vector<int64_t>& x_dims, bool trans_x,
-                        std::vector<int64_t>& y_dims, bool trans_y,
+                        const mkldnn::engine engine,
+                        paddle::platform::Place cpu_place,
+                        const std::vector<int64_t>& x_org_dims, bool trans_x,
+                        const std::vector<int64_t>& y_org_dims, bool trans_y,
                         const std::string& uniq_name)
-      : platform::MKLDNNHandlerT<T, dnnl::matmul>(
+      : paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
             dev_ctx, engine, cpu_place,
-            platform::CreateKey(dev_ctx, x_dims, uniq_name)) {
+            paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
     if (!this->isCached()) {
       // M X K * K X N
+      std::vector<int64_t> x_dims(x_org_dims);
+      std::vector<int64_t> y_dims(y_org_dims);
       const int MB_idx = x_dims.size() - 3;
       const int H_idx = x_dims.size() - 2;
       const int W_idx = x_dims.size() - 1;
@@ -104,10 +108,44 @@ class MatMulV2MKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
 };
 
 template <typename T>
-class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
+class MatMulV2MKLDNNKernel
+    : public paddle::operators::MatMulGradMKLDNNKernel<T> {
  public:
   void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
 
+ protected:
+  void ExecuteMatMul(const ExecutionContext& ctx,
+                     const MKLDNNDeviceContext& dev_ctx,
+                     const mkldnn::engine onednn_engine,
+                     paddle::platform::Place cpu_place, const Tensor* x,
+                     std::vector<int64_t>& x_dims, bool trans_x,
+                     const Tensor* y, std::vector<int64_t>& y_dims,
+                     bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
+                     int execution_number = 0) const {
+    MatMulV2MKLDNNHandler<T> handler(
+        dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
+        trans_y, ctx.InputName("X") + std::to_string(execution_number));
+
+    const auto src_memory_p = handler.AcquireSrcMemory(x);
+    const auto weights_memory_p = handler.AcquireWeightsMemory(y);
+    const auto dst_memory_p = handler.AcquireDstMemory(out);
+
+    auto matmul_p = handler.AcquireForwardPrimitive();
+
+    std::unordered_map<int, memory> matmul_args = {
+        {DNNL_ARG_SRC, *src_memory_p},
+        {DNNL_ARG_WEIGHTS, *weights_memory_p},
+        {DNNL_ARG_DST, *dst_memory_p}};
+
+    auto& astream = MKLDNNDeviceContext::tls().get_stream();
+    matmul_p->execute(astream, matmul_args);
+    astream.wait();
+
+    out->set_layout(paddle::framework::DataLayout::kMKLDNN);
+    out->set_format(
+        GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
+  }
+
  private:
   void CalculateMatrixDims(const ExecutionContext& ctx,
                            const std::vector<int64_t>& x_dims,
@@ -117,6 +155,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
                            std::vector<int64_t>& out_dims, Tensor* out) const {
     if (x_dims.size() == 1) {
       x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
+    } else if (x_dims.size() == 2) {
+      x_bd_dims[2] = x_dims[1];
+      x_bd_dims[1] = x_dims[0];
     } else {
       for (size_t i = 0; i < x_dims.size(); ++i) {
         x_bd_dims[i] = x_dims[i];
@@ -124,6 +165,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
     }
     if (y_dims.size() == 1) {
       y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
+    } else if (y_dims.size() == 2) {
+      y_bd_dims[2] = y_dims[1];
+      y_bd_dims[1] = y_dims[0];
     } else {
       for (size_t i = 0; i < y_dims.size(); ++i) {
         y_bd_dims[i] = y_dims[i];
@@ -134,14 +178,14 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
     for (size_t i = 0; i < x_dims.size() - 2; ++i) {
       PADDLE_ENFORCE_EQ(
           x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
-          platform::errors::InvalidArgument(
+          paddle::platform::errors::InvalidArgument(
               "Tensor dimensions are incorrect for broadcasting."
               "Dimensions in X and Y must be same or equal to 1, but "
               "received x_dim[%d]=%d and y_dims[%d]= %d",
               i, x_dims[i], i, y_dims[i]));
       out_dims[i] = std::max(x_dims[i], y_dims[i]);
     }
-    out->Resize(framework::make_ddim(out_dims));
+    out->Resize(make_ddim(out_dims));
   }
 }
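`CalculateMatrixDims` pads 1-D and 2-D inputs into the at-least-3-D (batch, rows, cols) layout the oneDNN matmul primitive works on: a 1-D X becomes a row, a 1-D Y becomes a column, and a 2-D input fills the two trailing matrix slots. A sketch of that padding in plain Python (the helper `to_bd_dims` is hypothetical, and the 2-D branch mirrors the kernel's hard-coded indices 1 and 2, i.e. it assumes ndims == 3 as in the 2-D cases the tests exercise):

```python
def to_bd_dims(dims, ndims, is_y=False):
    # Hypothetical mirror of CalculateMatrixDims: pad to ndims (>= 3)
    # entries, with leading batch dims defaulting to 1.
    bd = [1] * ndims
    if len(dims) == 1:
        # vector X -> 1xK row; vector Y -> Kx1 column
        bd[-2 if is_y else -1] = dims[0]
    elif len(dims) == 2:
        bd[1], bd[2] = dims[0], dims[1]   # kernel assumes ndims == 3 here
    else:
        bd[:len(dims)] = dims
    return bd

print(to_bd_dims([12, 9], 3))            # [1, 12, 9]
print(to_bd_dims([100], 3, is_y=True))   # [1, 100, 1]
```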
@@ -155,9 +199,9 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
     bool trans_x = ctx.Attr<bool>("trans_x");
     bool trans_y = ctx.Attr<bool>("trans_y");
 
-    auto x_dims = framework::vectorize(x->dims());
-    auto y_dims = framework::vectorize(y->dims());
-    auto out_dims = framework::vectorize(out->dims());
+    auto x_dims = vectorize(x->dims());
+    auto y_dims = vectorize(y->dims());
+    auto out_dims = vectorize(out->dims());
 
     int ndims = std::max(x->dims().size(), y->dims().size());
     ndims = std::max(ndims, 3);
@@ -168,38 +212,166 @@ class MatMulV2MKLDNNKernel : public framework::OpKernel<T> {
     CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims,
                         out);
 
-    MatMulV2MKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(),
-                                     x_bd_dims, trans_x, y_bd_dims, trans_y,
-                                     ctx.InputName("X"));
+    ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims,
+                  trans_x, y, y_bd_dims, trans_y, out, out_dims);
   }
 };
-    const auto src_memory_p = handler.AcquireSrcMemory(x);
-    const auto weights_memory_p = handler.AcquireWeightsMemory(y);
-    const auto dst_memory_p = handler.AcquireDstMemory(out);
+template <typename T>
+class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
+ public:
+  void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); }
 
-    auto matmul_p = handler.AcquireForwardPrimitive();
+ private:
+  void CalculateGradMatrixDims(const ExecutionContext& ctx, Tensor* dx_tmp,
+                               Tensor* dy_tmp,
+                               const std::vector<int64_t>& dx_dims,
+                               const std::vector<int64_t>& dy_dims,
+                               std::vector<int64_t>& dx_bd_dims,
+                               std::vector<int64_t>& dy_bd_dims) const {
+    for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
+      if (dx_dims[i] != dy_dims[i]) {
+        if (dx_dims[i] == 1) {
+          dx_bd_dims[i] = dy_dims[i];
+        } else {
+          dy_bd_dims[i] = dx_dims[i];
+        }
+      }
+    }
 
-    std::unordered_map<int, memory> matmul_args = {
-        {DNNL_ARG_SRC, *src_memory_p},
-        {DNNL_ARG_WEIGHTS, *weights_memory_p},
-        {DNNL_ARG_DST, *dst_memory_p}};
+    dx_tmp->Resize(make_ddim(dx_bd_dims));
+    dx_tmp->mutable_data<T>(ctx.GetPlace());
+    dy_tmp->Resize(make_ddim(dy_bd_dims));
+    dy_tmp->mutable_data<T>(ctx.GetPlace());
+  }
+
+  void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx,
+                                    const MKLDNNDeviceContext& dev_ctx,
+                                    const mkldnn::engine onednn_engine,
+                                    const Tensor* dx_tmp, Tensor* dx,
+                                    std::vector<int64_t> dx_dims) const {
+    paddle::platform::ReductionMKLDNNHandler<T> handler(
+        dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
+        ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
+
+    auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
+    auto dst_memory_p = handler.AcquireDstMemory(dx);
+
+    std::unordered_map<int, dnnl::memory> reduction_args = {
+        {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
 
     auto& astream = MKLDNNDeviceContext::tls().get_stream();
-    matmul_p->execute(astream, matmul_args);
+    auto reduction_p = handler.AcquireForwardPrimitive();
+
+    reduction_p->execute(astream, reduction_args);
     astream.wait();
+  }
 
-    out->set_layout(framework::DataLayout::kMKLDNN);
-    out->set_format(
-        GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
+  void RunKernel(const ExecutionContext& ctx) const {
+    const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& onednn_engine = dev_ctx.GetEngine();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+
+    auto x_dims = vectorize(x->dims());
+    auto y_dims = vectorize(y->dims());
+
+    bool is_broadcast = true;
+    if (x_dims.size() <= 2 || y_dims.size() <= 2) {
+      is_broadcast = false;
+    } else if (x_dims.size() != y_dims.size()) {
+      is_broadcast = true;
+    } else {
+      is_broadcast =
+          !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_dims.size() - 2,
+                      y_dims.cbegin());
+    }
+
+    // if no broadcasting is needed, we can simply use matmul's grad and avoid
+    // using reduce_sum
+    if (!is_broadcast) {
+      paddle::operators::MatMulGradMKLDNNKernel<T>::Compute(ctx);
+      return;
+    }
+
+    auto* dout = ctx.Input<Tensor>(GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(GradVarName("Y"));
+
+    bool trans_x = ctx.Attr<bool>("trans_x");
+    bool trans_y = ctx.Attr<bool>("trans_y");
+
+    auto dout_dims = vectorize(dout->dims());
+
+    int ndims = std::max(x->dims().size(), y->dims().size());
+    ndims = std::max(ndims, 3);
+
+    // in broadcasting scenario new memory is required because
+    // reduce sum must be calculated upon broadcasted dims
+    Tensor dx_tmp, dy_tmp;
+
+    std::vector<int64_t> dx_bd_dims(x_dims);
+    std::vector<int64_t> dy_bd_dims(y_dims);
+
+    CalculateGradMatrixDims(ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, dx_bd_dims,
+                            dy_bd_dims);
+
+    if (trans_x && trans_y) {
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
+                          y_dims, true, dout, dout_dims, true, &dx_tmp,
+                          dx_bd_dims, 1);
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                          dout_dims, true, x, x_dims, true, &dy_tmp,
+                          dy_bd_dims, 2);
+    } else if (trans_x) {
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y,
+                          y_dims, false, dout, dout_dims, true, &dx_tmp,
+                          dx_bd_dims, 1);
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
+                          x_dims, false, dout, dout_dims, false, &dy_tmp,
+                          dy_bd_dims, 2);
+    } else if (trans_y) {
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                          dout_dims, false, y, y_dims, false, &dx_tmp,
+                          dx_bd_dims, 1);
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                          dout_dims, true, x, x_dims, false, &dy_tmp,
+                          dy_bd_dims, 2);
+    } else {
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout,
+                          dout_dims, false, y, y_dims, true, &dx_tmp,
+                          dx_bd_dims, 1);
+      this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x,
+                          x_dims, true, dout, dout_dims, false, &dy_tmp,
+                          dy_bd_dims, 2);
+    }
+
+    if (x_dims != dx_bd_dims) {
+      ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
+                                   x_dims);
+    } else {
+      *dx = std::move(dx_tmp);
+    }
+    if (y_dims != dy_bd_dims) {
+      ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
+                                   y_dims);
+    } else {
+      *dy = std::move(dy_tmp);
+    }
+
+    dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
+    dx->set_format(x->format());
+    dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
+    dy->set_format(y->format());
+  }
+};
+
-}  // namespace operators
-}  // namespace paddle
+}  // anonymous namespace
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::MatMulV2MKLDNNKernel<float>,
-                   ops::MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
+                   MatMulV2MKLDNNKernel<float>,
+                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
 
-// REGISTER_OP_KERNEL(matmul_grad_v2, MKLDNN, ::paddle::platform::CPUPlace,
-//                    ops::MatMulV2GradMKLDNNKernel<float>,
-//                    ops::MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, ::paddle::platform::CPUPlace,
+                   MatMulV2GradMKLDNNKernel<float>,
+                   MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
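For reference, the four trans_x/trans_y branches above implement the textbook matmul gradients (e.g. for Out = X·Y, dX = dOut·Yᵀ and dY = Xᵀ·dOut), and the trailing ReduceSumForMatmulGradOutput collapses any batch dimension that was broadcast. A self-contained numpy sketch of the same math, assuming equal-rank, at-least-3-D inputs whose batch dims differ only where one side is 1 (the function names are ours):

```python
import numpy as np

def T(a):
    # transpose only the last two axes; batch dims stay in place
    return np.swapaxes(a, -1, -2)

def matmul_v2_grad(x, y, dout, trans_x, trans_y):
    # Same case split as MatMulV2GradMKLDNNKernel::RunKernel above.
    if trans_x and trans_y:
        dx, dy = np.matmul(T(y), T(dout)), np.matmul(T(dout), T(x))
    elif trans_x:
        dx, dy = np.matmul(y, T(dout)), np.matmul(x, dout)
    elif trans_y:
        dx, dy = np.matmul(dout, y), np.matmul(T(dout), x)
    else:
        dx, dy = np.matmul(dout, T(y)), np.matmul(T(x), dout)

    def reduce_to(grad, shape):
        # counterpart of ReduceSumForMatmulGradOutput: sum over batch dims
        # that were broadcast (1 in the input, larger in the raw gradient)
        axes = tuple(i for i, (s, g)
                     in enumerate(zip(shape[:-2], grad.shape[:-2]))
                     if s == 1 and g != 1)
        return grad.sum(axis=axes, keepdims=True) if axes else grad

    return reduce_to(dx, x.shape), reduce_to(dy, y.shape)

x = np.random.rand(2, 1, 12, 9)
y = np.random.rand(1, 3, 9, 12)
dout = np.ones((2, 3, 12, 12))
dx, dy = matmul_v2_grad(x, y, dout, False, False)
print(dx.shape, dy.shape)  # (2, 1, 12, 9) (1, 3, 9, 12)
```

This is also why the kernel falls back to the plain matmul grad when `is_broadcast` is false: with matching batch dims `reduce_to` is a no-op and the extra temporaries and reduction primitive would be wasted work.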
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import unittest
+from functools import reduce
 
 import numpy as np
 from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
@@ -23,14 +24,12 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 
 paddle.enable_static()
 
-def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
+def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
     """Reference forward implementation using np.matmul."""
     # np.matmul does not support the transpose flags, so we manually
    # transpose X and Y appropriately.
-    if transpose_X:
+    if transpose_x:
         if X.ndim == 1:
             X = X.reshape((X.size, ))
         elif X.ndim == 2:
@@ -39,7 +38,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
         dim = [i for i in range(len(X.shape))]
         dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
         X = np.transpose(X, tuple(dim))
-    if transpose_Y:
+    if transpose_y:
         if Y.ndim == 1:
             Y = Y.reshape((Y.size, ))
         else:
@@ -144,8 +143,8 @@ class TestMatMulV2MatrixXMatrixTransposeYOneDNNOp(
 class TestMatMulV2MatrixXMatrix2OneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
     def config(self):
-        self.x_shape = (1, 1, 12, 4)
-        self.y_shape = (1, 2, 4, 12)
+        self.x_shape = (2, 1, 12, 9)
+        self.y_shape = (1, 3, 9, 12)
         self.trans_x = False
         self.trans_y = False
@@ -170,8 +169,8 @@ class TestMatMulV2MatrixXMatrixTranposeXOneDNNOp2(
 class TestMatMulV2MatrixXMatrixTranposeX2OneDNNOp3(
         TestMatMulV2VectorXVectorOneDNNOp):
     def config(self):
-        self.x_shape = (2, 2, 5, 4)
-        self.y_shape = (2, 2, 5, 3)
+        self.x_shape = (2, 2, 7, 4)
+        self.y_shape = (2, 2, 7, 5)
         self.trans_x = True
         self.trans_y = False
@@ -179,7 +178,7 @@ class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp(
 class TestMatMulV2MatrixXMatrixTransposeX3OneDNNOp(
         TestMatMulV2VectorXVectorOneDNNOp):
     def config(self):
-        self.x_shape = (3, 1, 6, 5)
+        self.x_shape = (3, 1, 6, 7)
         self.y_shape = (1, 2, 6, 9)
         self.trans_x = True
         self.trans_y = False
@@ -203,8 +202,8 @@ class TestMatMulV2VectorXMatrix5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
 class TestMatMulV2Matrix3DXVectorOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
     def config(self):
-        self.x_shape = (2, 1, 40)
-        self.y_shape = (40)
+        self.x_shape = (2, 1, 100)
+        self.y_shape = (100)
         self.trans_x = False
         self.trans_y = False
@@ -245,6 +244,8 @@ def create_bf16_test_class(parent):
                 'X': convert_float_to_uint16(x),
                 'Y': convert_float_to_uint16(y)
             }
+            self.x_fp32 = x
+            self.y_fp32 = y
 
         def set_dtype_attr(self):
             self.attrs['mkldnn_data_type'] = "bfloat16"
@@ -253,7 +254,99 @@ def create_bf16_test_class(parent):
             self.check_output_with_place(core.CPUPlace())
 
         def test_check_grad(self):
-            pass
+            self.calculate_grads()
+            self.check_grad_with_place(
+                core.CPUPlace(), ["X", "Y"],
+                "Out",
+                user_defined_grads=[self.dx, self.dy],
+                user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
+
+        def matmul_grad(self, x, transpose_x, y, transpose_y):
+            x = np.transpose(
+                x, self.shape_transpose_axes[x.ndim]) if transpose_x else x
+            y = np.transpose(
+                y, self.shape_transpose_axes[y.ndim]) if transpose_y else y
+
+            return np.matmul(x, y)
+
+        def calculate_grads(self):
+            self.shape_transpose_axes = {
+                2: [1, 0],
+                3: [0, 2, 1],
+                4: [0, 1, 3, 2],
+                5: [0, 1, 2, 4, 3]
+            }
+
+            # expand vector so it will be a valid matrix for multiplication
+            if self.x_fp32.ndim == 1:
+                self.x_fp32 = np.expand_dims(self.x_fp32, axis=0)
+            if self.y_fp32.ndim == 1:
+                self.y_fp32 = np.expand_dims(self.y_fp32, axis=1)
+
+            x_transpose_axes = self.shape_transpose_axes[self.x_fp32.ndim]
+            y_transpose_axes = self.shape_transpose_axes[self.y_fp32.ndim]
+
+            x = np.transpose(self.x_fp32, x_transpose_axes) if self.attrs[
+                'trans_x'] is True else self.x_fp32
+            y = np.transpose(self.y_fp32, y_transpose_axes) if self.attrs[
+                'trans_y'] is True else self.y_fp32
+
+            dout = np.matmul(x, y)
+
+            x_shape = x.shape
+            y_shape = y.shape
+
+            if x.ndim <= 2 or y.ndim <= 2:
+                is_broadcast = False
+            elif x.ndim != y.ndim:
+                is_broadcast = True
+            else:
+                is_broadcast = x.shape[0:-2] != y.shape[0:-2]
+
+            if self.attrs['trans_x'] is True and self.attrs['trans_y'] is True:
+                self.dx = self.matmul_grad(self.y_fp32, True, dout, True)
+                self.dy = self.matmul_grad(dout, True, self.x_fp32, True)
+            elif self.attrs['trans_x'] is True and self.attrs[
+                    'trans_y'] is False:
+                self.dx = self.matmul_grad(self.y_fp32, False, dout, True)
+                self.dy = self.matmul_grad(self.x_fp32, False, dout, False)
+            elif self.attrs['trans_x'] is False and self.attrs[
+                    'trans_y'] is True:
+                self.dx = self.matmul_grad(dout, False, self.y_fp32, False)
+                self.dy = self.matmul_grad(dout, True, self.x_fp32, False)
+            else:
+                self.dx = self.matmul_grad(dout, False, self.y_fp32, True)
+                self.dy = self.matmul_grad(self.x_fp32, True, dout, False)
+
+            if is_broadcast:
+                x_reduce_axis = []
+                y_reduce_axis = []
+                for index, (
+                        first, second
+                ) in enumerate(zip(x_shape[0:-2], self.dx.shape[0:-2])):
+                    if first != second:
+                        x_reduce_axis.append(index)
+
+                for index, (
+                        first, second
+                ) in enumerate(zip(y_shape[0:-2], self.dy.shape[0:-2])):
+                    if first != second:
+                        y_reduce_axis.append(index)
+
+                if x_reduce_axis:
+                    self.dx = self.dx.sum(axis=tuple(x_reduce_axis),
+                                          keepdims=True)
+                if y_reduce_axis:
+                    self.dy = self.dy.sum(axis=tuple(y_reduce_axis),
+                                          keepdims=True)
+
+            # after multiplying with vector one dimension is deleted from tensor
+            if len(x_shape) == 2 and x_shape[0] == 1:
+                dout = dout.sum(axis=-2)
+            if len(y_shape) == 2 and y_shape[1] == 1:
+                dout = dout.sum(axis=-1)
+
+            self.dout = dout
 
     cls_name = "{0}_{1}".format(parent.__name__, "BF16")
     TestMatMulV2Bf16OneDNNOp.__name__ = cls_name
......
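The two trailing sums in `calculate_grads` mirror np.matmul's convention that the dimension inserted for a 1-D operand is removed from the result, so the row/column the test expanded earlier must be squeezed back out. A quick check using the shapes from TestMatMulV2Matrix3DXVectorOneDNNOp (nothing beyond numpy assumed):

```python
import numpy as np

x = np.random.rand(2, 1, 100)   # batched matrix X
y = np.random.rand(100)         # 1-D Y acts as a 100x1 column...
out = np.matmul(x, y)
print(out.shape)                # (2, 1): ...and its inserted dim is squeezed

# the test reproduces this by expanding y to (100, 1), multiplying, then
# summing the trailing axis away, as in the `y_shape[1] == 1` branch above
y2 = np.expand_dims(y, axis=1)
out2 = np.matmul(x, y2).sum(axis=-1)
print(np.allclose(out, out2))   # True
```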