未验证 提交 ce6a27d9 编写于 作者: J jakpiase 提交者: GitHub

fix for matmul_v2 6D x 2D (#36379)

上级 a6868c91
......@@ -148,8 +148,8 @@ class MatMulV2MKLDNNKernel
if (x_dims.size() == 1) {
x_bd_dims[x_bd_dims.size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
x_bd_dims[2] = x_dims[1];
x_bd_dims[1] = x_dims[0];
x_bd_dims[x_bd_dims.size() - 1] = x_dims[1];
x_bd_dims[x_bd_dims.size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
x_bd_dims[i] = x_dims[i];
......@@ -158,8 +158,8 @@ class MatMulV2MKLDNNKernel
if (y_dims.size() == 1) {
y_bd_dims[x_bd_dims.size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
y_bd_dims[2] = y_dims[1];
y_bd_dims[1] = y_dims[0];
y_bd_dims[y_bd_dims.size() - 1] = y_dims[1];
y_bd_dims[y_bd_dims.size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
y_bd_dims[i] = y_dims[i];
......
......@@ -235,6 +235,22 @@ class TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp(
self.trans_y = True
class TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (1, 1, 2, 1, 8, 9)
self.y_shape = (9, 12)
self.trans_x = False
self.trans_y = False
class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp):
def config(self):
self.x_shape = (20, 5)
self.y_shape = (1, 2, 1, 5, 11)
self.trans_x = False
self.trans_y = False
# BF16 TESTS
def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16()
......@@ -274,7 +290,8 @@ def create_bf16_test_class(parent):
2: [1, 0],
3: [0, 2, 1],
4: [0, 1, 3, 2],
5: [0, 1, 2, 4, 3]
5: [0, 1, 2, 4, 3],
6: [0, 1, 2, 3, 5, 4]
}
# expand vector so it will be a valid matrix for multiplication
......@@ -370,6 +387,8 @@ create_bf16_test_class(TestMatMulV2Matrix3DXVectorOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeXTransposeYOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeY2OneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp)
if __name__ == "__main__":
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册