未验证 提交 1e8432f2 编写于 作者: J jakpiase 提交者: GitHub

[CHERRY-PICK] Fix for matmul_v2 oneDNN op broadcasting when inputs dims have...

[CHERRY-PICK] Fix for matmul_v2 oneDNN op broadcasting when inputs dims have different lengths (#38733)

* fix for matmul_v2 broadcasting

* resolved conflicts
上级 457fe72c
......@@ -152,7 +152,7 @@ class MatMulV2MKLDNNKernel
x_bd_dims[x_bd_dims.size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i];
x_bd_dims[x_bd_dims.size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
......@@ -162,20 +162,21 @@ class MatMulV2MKLDNNKernel
y_bd_dims[y_bd_dims.size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i];
y_bd_dims[y_bd_dims.size() - y_dims.size() + i] = y_dims[i];
}
}
if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) {
for (size_t i = 0; i < x_dims.size() - 2; ++i) {
if (x_dims.size() > 2 && y_dims.size() > 2) {
for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_dims[i], i, y_dims[i]));
out_dims[i] = std::max(x_dims[i], y_dims[i]);
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 ||
y_bd_dims[i] == 1,
true, paddle::platform::errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i, x_bd_dims[i], i, y_bd_dims[i]));
out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]);
}
out->Resize(make_ddim(out_dims));
}
......@@ -237,11 +238,11 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine onednn_engine,
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
void ReduceSumForMatmulGradOutput(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t>& dx_dims,
const std::vector<int64_t>& squeezed_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, dx_dims);
......@@ -257,6 +258,19 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_format(paddle::platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(squeezed_dims)));
}
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t>& dims,
int new_size) const {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
}
return new_dims;
}
void RunKernel(const ExecutionContext& ctx) const {
......@@ -295,8 +309,14 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
bool trans_y = ctx.Attr<bool>("trans_y");
auto dout_dims = vectorize(dout->dims());
int ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max(ndims, 3);
size_t ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max<size_t>(ndims, 3);
if (x_dims.size() != ndims) {
x_dims = ExtendDimsWithOnes(x_dims, ndims);
} else if (y_dims.size() != ndims) {
y_dims = ExtendDimsWithOnes(y_dims, ndims);
}
// in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims
......@@ -340,21 +360,21 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx,
x_dims);
x_dims,
paddle::framework::vectorize(x->dims()));
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy,
y_dims);
y_dims,
paddle::framework::vectorize(y->dims()));
} else {
*dy = std::move(dy_tmp);
}
dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_format(x->format());
dy->set_layout(paddle::framework::DataLayout::kMKLDNN);
dy->set_format(y->format());
dx->Resize(x->dims());
dy->Resize(y->dims());
}
};
} // anonymous namespace
......
......@@ -251,6 +251,41 @@ class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
self.trans_y = False
class TestMatMulV2MatrixXMatrix4Dx3DTransposeXOneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (5, 4, 15, 10)
self.y_shape = (1, 15, 20)
self.trans_x = True
self.trans_y = False
class TestMatMulV2MatrixXMatrix3Dx4DTransposeYOneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (2, 10, 15)
self.y_shape = (4, 2, 20, 15)
self.trans_x = False
self.trans_y = True
class TestMatMulV2MatrixXMatrix5Dx3DTransposeXTransposeYOneDNNOp(
TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (4, 3, 2, 15, 10)
self.y_shape = (1, 20, 15)
self.trans_x = True
self.trans_y = True
class TestMatMulV2MatrixXMatrix3Dx4DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (1, 1, 32, 16)
self.y_shape = (16, 16, 16)
self.trans_x = False
self.trans_y = False
# BF16 TESTS
def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册