diff --git a/paddle/operators/math/matmul.h b/paddle/operators/math/matmul.h index 7048e11e6f27a075892c28681a3c4913a27b3f3e..ae7f1fe9be5066a0ac0ac522849d481fc66a19be 100644 --- a/paddle/operators/math/matmul.h +++ b/paddle/operators/math/matmul.h @@ -41,10 +41,24 @@ class MatMulFunctor { "Input tensor a must be at least 1-dimensional."); PADDLE_ENFORCE_GE(dim_b.size(), 1, "Input tensor b must be at least 1-dimensional."); - PADDLE_ENFORCE_LE(dim_a.size(), 3, - "Input tensor a must be at most 3-dimensional."); - PADDLE_ENFORCE_LE(dim_b.size(), 3, - "Input tensor b must be at most 3-dimensional."); + + std::vector out_dim; + int64_t batch_count = 1; + if (dim_a.size() > 3) { + PADDLE_ENFORCE(dim_b.size() == dim_a.size(), + "The dimensions of X and Y must be the same, and both of " + "them should be %d-dimensional.", + dim_b.size()); + // The first rank-2 dimensions are accumulated on the batch_count, and the + // last two dimensions are used for matrix multiplication. + for (int j = 0; j < dim_a.size() - 2; ++j) { + PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j], + "The %d-th dimension of X and Y must be the same.", + j); + out_dim.push_back(dim_a[j]); + batch_count *= dim_a[j]; + } + } int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0, strideA = 0, strideB = 0; @@ -67,7 +81,11 @@ class MatMulFunctor { strideA = M * kA; break; default: - assert(false); + batchCountA = batch_count; + size_t mat_s = dim_a.size() - 2; + M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s]; + kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1]; + strideA = M * kA; } switch (dim_b.size()) { @@ -88,7 +106,11 @@ class MatMulFunctor { strideB = kB * N; break; default: - assert(false); + batchCountB = batch_count; + size_t mat_s = dim_b.size() - 2; + kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s]; + N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1]; + strideB = kB * N; } PADDLE_ENFORCE_EQ( diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc index fd65d894d5749c97f860d614de354e89f6d9441d..3336978c8d8d94c69b986970274661a01ad2161d 100644 --- a/paddle/operators/matmul_op.cc +++ b/paddle/operators/matmul_op.cc @@ -41,10 +41,26 @@ class MatMulOp : public framework::OperatorWithKernel { "Input tensor X must be at least 1-dimensional."); PADDLE_ENFORCE_GE(dim_y.size(), 1, "Input tensor Y must be at least 1-dimensional."); - PADDLE_ENFORCE_LE(dim_x.size(), 3, - "Input tensor X must be at most 3-dimensional."); - PADDLE_ENFORCE_LE(dim_y.size(), 3, - "Input tensor Y must be at most 3-dimensional."); + + std::vector out_dim; + int64_t batch_count = 1; + if (dim_x.size() > 3) { + PADDLE_ENFORCE_EQ( + dim_y.size(), dim_x.size(), + "The dimensions of X and Y must be the same, and both of " + "them should be %d-dimensional.", + dim_x.size()); + + // The first rank-2 dimensions are accumulated on the batch_count, and the + // last two dimensions are used for matrix multiplication. + for (int j = 0; j < dim_x.size() - 2; ++j) { + PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j], + "The %d-th dimension of X and Y must be the same.", + j); + out_dim.push_back(dim_x[j]); + batch_count *= dim_x[j]; + } + } int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0; bool remove_initial_dim = false, remove_final_dim = false; @@ -70,7 +86,11 @@ class MatMulOp : public framework::OperatorWithKernel { KX = transpose_x ? dim_x[1] : dim_x[2]; break; default: - assert(false); + batchCountX = batch_count; + size_t mat_s = dim_x.size() - 2; + M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s]; + KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1]; + break; } switch (dim_y.size()) { @@ -94,7 +114,10 @@ class MatMulOp : public framework::OperatorWithKernel { N = transpose_y ? dim_y[1] : dim_y[2]; break; default: - assert(false); + batchCountY = batch_count; + size_t mat_s = dim_y.size() - 2; + KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s]; + N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1]; } PADDLE_ENFORCE_EQ( @@ -110,7 +133,11 @@ class MatMulOp : public framework::OperatorWithKernel { std::vector dim_out; if (batchCount) { - dim_out.push_back(batchCount); + if (dim_x.size() > 3) { + dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end()); + } else { + dim_out.push_back(batchCount); + } } if (!remove_initial_dim) { dim_out.push_back(M); @@ -162,10 +189,14 @@ Examples without transpose: - X: [B, M, K], Y: [K] => Out: [B, M] - X: [M, K], Y: [B, K, N] => Out: [B, M, N] - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] +- X: [B, ..., M, K], Y: [B, ..., K, N] => Out: [B, ..., M, N] The behavior is designed to be similar to the `numpy.matmul` function. The differences are: -- Currently only rank 1 to rank 3 input tensors are supported. +- When the rank of the input data is less than or equal to 3, it + is similar to the `numpy.matmul` function. +- When the rank of the input is greater than 3, the rank of X and + Y must be equal, and the first `rank - 2` dimensions must be equal. - We add `transpose_X` and `transpose_Y` flags. Both the input `X` and `Y` can carry the LoD (Level of Details) information, diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 78adc64f76f45afce64c49bcf734647e0db2d6b3..fe6a97465f8992281c909d4600bcfd8121d6a64a 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -137,6 +137,13 @@ class MatMulGradKernel : public framework::OpKernel { y_dims.push_back(1); } + int batch_count = 0; + // The first rank-2 dimensions are accumulated on the batch_count, and the + // last two dimensions are used for matrix multiplication. + if (x_dims.size() > 3) { + batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, + std::multiplies()); + } // Fix the dOut dimensions. int M = 0, N = 0, batchCountX = 0, batchCountY = 0; @@ -149,7 +156,9 @@ class MatMulGradKernel : public framework::OpKernel { M = transpose_x ? x_dims[2] : x_dims[1]; break; default: - assert(false); + batchCountX = batch_count; + size_t mat_s = x_dims.size() - 2; + M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s]; } switch (y_dims.size()) { @@ -161,7 +170,9 @@ class MatMulGradKernel : public framework::OpKernel { N = transpose_y ? y_dims[1] : y_dims[2]; break; default: - assert(false); + batchCountY = batch_count; + size_t mat_s = y_dims.size() - 2; + N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1]; } if (batchCountX && batchCountY) { PADDLE_ENFORCE_EQ( @@ -172,7 +183,11 @@ class MatMulGradKernel : public framework::OpKernel { int batchCount = std::max(batchCountX, batchCountY); std::vector dout_dims = {M, N}; if (batchCount) { - dout_dims.insert(dout_dims.begin(), batchCount); + if (x_dims.size() > 3) { + dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2); + } else { + dout_dims.insert(dout_dims.begin(), batchCount); + } } Tensor X = Reshape(x, make_ddim(x_dims)); Tensor Y = Reshape(y, make_ddim(y_dims)); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index cbd2cfec597622a362f9f09a46aa88a2f86e8733..78d162c09764a129375cdfc153db9f7fc69d08b2 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1794,8 +1794,9 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): def matmul(x, y, transpose_x=False, transpose_y=False, name=None): """ - Applies matrix multipication to two tensors. Currently only rank 1 to rank - 3 input tensors are supported. + Applies matrix multiplication to two tensors. Currently, the input + tensors' rank can be any, but when the rank of anyone inputs is + bigger than 3, this two inputs' rank should be equal. The actual behavior depends on the shapes of :math:`x`, :math:`y` and the flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically: @@ -1807,17 +1808,17 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): opposite: It is treated as :math:`[D, 1]` in nontransposed form and as :math:`[1, D]` in transposed form. - - After transpose, the two tensors are 2-D or 3-D and matrix multipication + - After transpose, the two tensors are 2-D or n-D and matrix multiplication performs in the following way. - If both are 2-D, they are multiplied like conventional matrices. - - If either is 3-D, it is treated as a stack of matrices residing in the - last two dimensions and a batched matrix multiply supporting broadcast + - If either is n-D, it is treated as a stack of matrices residing in the + last two dimensions and a batched matrix multiply supporting broadcast applies on the two tensors. - Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and - nontransposed, the prepended or appended dimension :math:`1` will be - removed after matrix multipication. + Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and + nontransposed, the prepended or appended dimension :math:`1` will be + removed after matrix multiplication. Args: x (Variable): The input variable which is a Tensor or LoDTensor. @@ -1834,6 +1835,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): .. code-block:: python # Examples to clarify shapes of the inputs and output + # x: [B, ..., M, K], y: [B, ..., K, N] + fluid.layers.matmul(x, y) # out: [B, ..., M, N] # x: [B, M, K], y: [B, K, N] fluid.layers.matmul(x, y) # out: [B, M, N] # x: [B, M, K], y: [K, N] @@ -1849,9 +1852,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): fluid.layers.matmul(x, y, True, True) # out: [M, N] """ helper = LayerHelper('matmul', **locals()) - assert max( - len(x.shape), len(y.shape) - ) <= 3, 'Currently only rank 1 to rank 3 input tensors are supported.' + assert max(len(x.shape), len(y.shape)) <= 3 or len(x.shape) == len( + y. + shape), 'Inputs\' rank should be equal or their rank should be less 4.' out = helper.create_tmp_variable(dtype=helper.input_dtype()) helper.append_op( type='matmul', diff --git a/python/paddle/v2/fluid/tests/test_matmul_op.py b/python/paddle/v2/fluid/tests/test_matmul_op.py index f7dc4e053217dcceaf9c64e3605286fc0698593b..2dba25e17d3bf72563a44f629e94949302688b39 100644 --- a/python/paddle/v2/fluid/tests/test_matmul_op.py +++ b/python/paddle/v2/fluid/tests/test_matmul_op.py @@ -59,19 +59,18 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): X = X.reshape((X.size, 1)) elif X.ndim == 2: X = X.T - elif X.ndim == 3: - X = np.transpose(X, (0, 2, 1)) else: - raise ValueError('X must have between 1 and 3 dimensions') + dim = [i for i in range(len(X.shape))] + dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1] + X = np.transpose(X, tuple(dim)) if transpose_Y: if Y.ndim == 1: Y = Y.reshape((1, Y.size)) - elif Y.ndim == 2: - Y = Y.T - elif Y.ndim == 3: - Y = np.transpose(Y, (0, 2, 1)) else: - raise ValueError('Y must have between 1 and 3 dimensions') + dim = [i for i in range(len(Y.shape))] + dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1] + Y = np.transpose(Y, tuple(dim)) + Out = np.matmul(X, Y) if not Out.shape: # We do not support 0-dimensional Tensors (scalars). So where @@ -120,13 +119,50 @@ for dim_X in [1, 2, 3]: dim_X, dim_Y, transpose_X, transpose_Y)) shape_X, shape_Y = generate_compatible_shapes( dim_X, dim_Y, transpose_X, transpose_Y) - test_class = type(test_name, (Generator, OpTest), { + globals()[test_name] = type(test_name, (Generator, OpTest), { 'shape_X': shape_X, 'shape_Y': shape_Y, 'transpose_X': transpose_X, 'transpose_Y': transpose_Y, }) - globals()[test_name] = test_class + + +# Test case n-dim +def generate_compatible_shapes(dim, transpose_X, transpose_Y): + M = 2 + N = 4 + K = 3 + shape_X = [2 for _ in range(dim - 2)] + shape_Y = [2 for _ in range(dim - 2)] + + if transpose_X: + shape_X += [K, M] + else: + shape_X += [M, K] + + if transpose_Y: + shape_Y += [N, K] + else: + shape_Y += [K, N] + + return shape_X, shape_Y + + +# Test case n-dim +for dim in [4]: + for transpose_X in [False, True]: + for transpose_Y in [False, True]: + test_name = ( + 'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format( + dim, dim, transpose_X, transpose_Y)) + shape_X, shape_Y = generate_compatible_shapes(dim, transpose_X, + transpose_Y) + globals()[test_name] = type(test_name, (Generator, OpTest), { + 'shape_X': shape_X, + 'shape_Y': shape_Y, + 'transpose_X': transpose_X, + 'transpose_Y': transpose_Y, + }) if __name__ == "__main__": unittest.main()