diff --git a/paddle/operators/math/matmul.h b/paddle/operators/math/matmul.h index 98a98e4b994bad5efb361505c08cbb53b78e2b08..8a63d204cb0647f19631eb3b59205df1d45bb7c4 100644 --- a/paddle/operators/math/matmul.h +++ b/paddle/operators/math/matmul.h @@ -41,22 +41,18 @@ class MatMulFunctor { "Input tensor a must be at least 1-dimensional."); PADDLE_ENFORCE_GE(dim_b.size(), 1, "Input tensor b must be at least 1-dimensional."); - PADDLE_ENFORCE_LE(dim_a.size(), 4, - "Input tensor a must be at most 4-dimensional."); - PADDLE_ENFORCE_LE(dim_b.size(), 4, - "Input tensor b must be at most 4-dimensional."); std::vector out_dim; int64_t batch_count = 1; if (dim_a.size() > 3) { PADDLE_ENFORCE(dim_b.size() > 3, "The dimensions of X and Y must be the same, and both of " - "them should be 4-dimensional."); + "them should be %d-dimensional.", + dim_b.size()); for (int j = 0; j < dim_a.size() - 2; ++j) { - PADDLE_ENFORCE( - dim_b[j] == dim_a[j], - "The dimensions of X and Y must be the same, and both of " - "them should be 4-dimensional."); + PADDLE_ENFORCE(dim_b[j] == dim_a[j], + "The dimensions of X[%d] and Y[%d] must be the same.", j, + j); out_dim.push_back(dim_a[j]); batch_count *= dim_a[j]; } diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc index 155346db41d8310e6095642db9f34a3b283a041a..6ced0ef6c0a0648fe9b8ecb2d8034a6c043e35e5 100644 --- a/paddle/operators/matmul_op.cc +++ b/paddle/operators/matmul_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/matmul_op.h" -#include namespace paddle { namespace operators { @@ -42,22 +41,18 @@ class MatMulOp : public framework::OperatorWithKernel { "Input tensor X must be at least 1-dimensional."); PADDLE_ENFORCE_GE(dim_y.size(), 1, "Input tensor Y must be at least 1-dimensional."); - PADDLE_ENFORCE_LE(dim_x.size(), 4, - "Input tensor X must be at most 4-dimensional."); - PADDLE_ENFORCE_LE(dim_y.size(), 4, - "Input tensor Y must be at most 4-dimensional."); std::vector out_dim; int64_t batch_count = 1; if (dim_x.size() > 3) { PADDLE_ENFORCE(dim_y.size() == dim_x.size(), "The dimensions of X and Y must be the same, and both of " - "them should be 4-dimensional."); + "them should be %d-dimensional.", + dim_x.size()); for (int j = 0; j < dim_x.size() - 2; ++j) { - PADDLE_ENFORCE( - dim_y[j] == dim_x[j], - "The dimensions of X and Y must be the same, and both of " - "them should be 4-dimensional."); + PADDLE_ENFORCE(dim_y[j] == dim_x[j], + "The dimensions of X[%d] and Y[%d] must be the same.", j, + j); out_dim.push_back(dim_x[j]); batch_count *= dim_x[j]; } diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 11266db4b9a7e31b42c56cc4a9ea7175f555786a..9f06791f7b16f2d08473fc37e1a11c94b7452a8d 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -137,6 +137,12 @@ class MatMulGradKernel : public framework::OpKernel { y_dims.push_back(1); } + int batch_count = 0; + // + if (x_dims.size() > 3) { + batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, + std::multiplies()); + } // Fix the dOut dimensions. int M = 0, N = 0, batchCountX = 0, batchCountY = 0; @@ -149,8 +155,7 @@ class MatMulGradKernel : public framework::OpKernel { M = transpose_x ? x_dims[2] : x_dims[1]; break; default: - batchCountX = accumulate(x_dims.begin(), x_dims.end() - 2, 1, - std::multiplies()); + batchCountX = batch_count; size_t mat_s = x_dims.size() - 2; M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s]; } @@ -164,8 +169,7 @@ class MatMulGradKernel : public framework::OpKernel { N = transpose_y ? y_dims[1] : y_dims[2]; break; default: - batchCountY = accumulate(y_dims.begin(), y_dims.end() - 2, 1, - std::multiplies()); + batchCountY = batch_count; size_t mat_s = y_dims.size() - 2; N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1]; } @@ -180,8 +184,6 @@ class MatMulGradKernel : public framework::OpKernel { if (batchCount) { if (x_dims.size() > 3) { dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2); - } else if (y_dims.size() > 3) { - dout_dims.insert(dout_dims.begin(), y_dims.begin(), y_dims.end() - 2); } else { dout_dims.insert(dout_dims.begin(), batchCount); }