diff --git a/paddle/operators/math/matmul.h b/paddle/operators/math/matmul.h index 7048e11e6f27a075892c28681a3c4913a27b3f3e..98a98e4b994bad5efb361505c08cbb53b78e2b08 100644 --- a/paddle/operators/math/matmul.h +++ b/paddle/operators/math/matmul.h @@ -41,10 +41,26 @@ class MatMulFunctor { "Input tensor a must be at least 1-dimensional."); PADDLE_ENFORCE_GE(dim_b.size(), 1, "Input tensor b must be at least 1-dimensional."); - PADDLE_ENFORCE_LE(dim_a.size(), 3, - "Input tensor a must be at most 3-dimensional."); - PADDLE_ENFORCE_LE(dim_b.size(), 3, - "Input tensor b must be at most 3-dimensional."); + PADDLE_ENFORCE_LE(dim_a.size(), 4, + "Input tensor a must be at most 4-dimensional."); + PADDLE_ENFORCE_LE(dim_b.size(), 4, + "Input tensor b must be at most 4-dimensional."); + + std::vector out_dim; + int64_t batch_count = 1; + if (dim_a.size() > 3) { + PADDLE_ENFORCE(dim_b.size() > 3, + "The dimensions of X and Y must be the same, and both of " + "them should be 4-dimensional."); + for (int j = 0; j < dim_a.size() - 2; ++j) { + PADDLE_ENFORCE( + dim_b[j] == dim_a[j], + "The dimensions of X and Y must be the same, and both of " + "them should be 4-dimensional."); + out_dim.push_back(dim_a[j]); + batch_count *= dim_a[j]; + } + } int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0, strideA = 0, strideB = 0; @@ -67,7 +83,11 @@ class MatMulFunctor { strideA = M * kA; break; default: - assert(false); + batchCountA = batch_count; + size_t mat_s = dim_a.size() - 2; + M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s]; + kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1]; + strideA = M * kA; } switch (dim_b.size()) { @@ -88,7 +108,11 @@ class MatMulFunctor { strideB = kB * N; break; default: - assert(false); + batchCountB = batch_count; + size_t mat_s = dim_b.size() - 2; + kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s]; + N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1]; + strideB = kB * N; } PADDLE_ENFORCE_EQ( diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc index fd65d894d5749c97f860d614de354e89f6d9441d..155346db41d8310e6095642db9f34a3b283a041a 100644 --- a/paddle/operators/matmul_op.cc +++ b/paddle/operators/matmul_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/matmul_op.h" +#include namespace paddle { namespace operators { @@ -41,10 +42,26 @@ class MatMulOp : public framework::OperatorWithKernel { "Input tensor X must be at least 1-dimensional."); PADDLE_ENFORCE_GE(dim_y.size(), 1, "Input tensor Y must be at least 1-dimensional."); - PADDLE_ENFORCE_LE(dim_x.size(), 3, - "Input tensor X must be at most 3-dimensional."); - PADDLE_ENFORCE_LE(dim_y.size(), 3, - "Input tensor Y must be at most 3-dimensional."); + PADDLE_ENFORCE_LE(dim_x.size(), 4, + "Input tensor X must be at most 4-dimensional."); + PADDLE_ENFORCE_LE(dim_y.size(), 4, + "Input tensor Y must be at most 4-dimensional."); + + std::vector out_dim; + int64_t batch_count = 1; + if (dim_x.size() > 3) { + PADDLE_ENFORCE(dim_y.size() == dim_x.size(), + "The dimensions of X and Y must be the same, and both of " + "them should be 4-dimensional."); + for (int j = 0; j < dim_x.size() - 2; ++j) { + PADDLE_ENFORCE( + dim_y[j] == dim_x[j], + "The dimensions of X and Y must be the same, and both of " + "them should be 4-dimensional."); + out_dim.push_back(dim_x[j]); + batch_count *= dim_x[j]; + } + } int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0; bool remove_initial_dim = false, remove_final_dim = false; @@ -70,7 +87,11 @@ class MatMulOp : public framework::OperatorWithKernel { KX = transpose_x ? dim_x[1] : dim_x[2]; break; default: - assert(false); + batchCountX = batch_count; + size_t mat_s = dim_x.size() - 2; + M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s]; + KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1]; + break; } switch (dim_y.size()) { @@ -94,7 +115,10 @@ class MatMulOp : public framework::OperatorWithKernel { N = transpose_y ? dim_y[1] : dim_y[2]; break; default: - assert(false); + batchCountY = batch_count; + size_t mat_s = dim_y.size() - 2; + KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s]; + N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1]; } PADDLE_ENFORCE_EQ( @@ -110,7 +134,11 @@ class MatMulOp : public framework::OperatorWithKernel { std::vector dim_out; if (batchCount) { - dim_out.push_back(batchCount); + if (dim_x.size() > 3) { + dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end()); + } else { + dim_out.push_back(batchCount); + } } if (!remove_initial_dim) { dim_out.push_back(M); diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 78adc64f76f45afce64c49bcf734647e0db2d6b3..11266db4b9a7e31b42c56cc4a9ea7175f555786a 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -149,7 +149,10 @@ class MatMulGradKernel : public framework::OpKernel { M = transpose_x ? x_dims[2] : x_dims[1]; break; default: - assert(false); + batchCountX = accumulate(x_dims.begin(), x_dims.end() - 2, 1, + std::multiplies()); + size_t mat_s = x_dims.size() - 2; + M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s]; } switch (y_dims.size()) { @@ -161,7 +164,10 @@ class MatMulGradKernel : public framework::OpKernel { N = transpose_y ? y_dims[1] : y_dims[2]; break; default: - assert(false); + batchCountY = accumulate(y_dims.begin(), y_dims.end() - 2, 1, + std::multiplies()); + size_t mat_s = y_dims.size() - 2; + N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1]; } if (batchCountX && batchCountY) { PADDLE_ENFORCE_EQ( @@ -172,7 +178,13 @@ class MatMulGradKernel : public framework::OpKernel { int batchCount = std::max(batchCountX, batchCountY); std::vector dout_dims = {M, N}; if (batchCount) { - dout_dims.insert(dout_dims.begin(), batchCount); + if (x_dims.size() > 3) { + dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2); + } else if (y_dims.size() > 3) { + dout_dims.insert(dout_dims.begin(), y_dims.begin(), y_dims.end() - 2); + } else { + dout_dims.insert(dout_dims.begin(), batchCount); + } } Tensor X = Reshape(x, make_ddim(x_dims)); Tensor Y = Reshape(y, make_ddim(y_dims)); diff --git a/python/paddle/v2/fluid/tests/test_matmul_op.py b/python/paddle/v2/fluid/tests/test_matmul_op.py index f7dc4e053217dcceaf9c64e3605286fc0698593b..0548a9bbfecca4ef2a4f7bfd2a20161dabbced0d 100644 --- a/python/paddle/v2/fluid/tests/test_matmul_op.py +++ b/python/paddle/v2/fluid/tests/test_matmul_op.py @@ -128,5 +128,22 @@ for dim_X in [1, 2, 3]: }) globals()[test_name] = test_class +# Test case 4-dim +dim_X = 4 +dim_Y = 4 +transpose_X = False +transpose_Y = False +test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format( + dim_X, dim_Y, transpose_X, transpose_Y)) + +shape_X = [2, 2, 2, 3] +shape_Y = [2, 2, 3, 4] +test_class = type(test_name, (Generator, OpTest), { + 'shape_X': shape_X, + 'shape_Y': shape_Y, + 'transpose_X': transpose_X, + 'transpose_Y': transpose_Y, +}) + if __name__ == "__main__": unittest.main()