From 0468422d06eb2094824141002dfb63f0dc03c513 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 19 Jan 2018 19:32:24 +0800 Subject: [PATCH] follow comments --- paddle/operators/math/matmul.h | 11 ++++++----- paddle/operators/matmul_op.cc | 22 ++++++++++++---------- paddle/operators/matmul_op.h | 3 ++- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/paddle/operators/math/matmul.h b/paddle/operators/math/matmul.h index ca41201e12..88341412fb 100644 --- a/paddle/operators/math/matmul.h +++ b/paddle/operators/math/matmul.h @@ -45,15 +45,16 @@ class MatMulFunctor { std::vector out_dim; int64_t batch_count = 1; if (dim_a.size() > 3) { - PADDLE_ENFORCE(dim_b.size() > 3, + PADDLE_ENFORCE(dim_b.size() == dim_a.size(), "The dimensions of X and Y must be the same, and both of " "them should be %d-dimensional.", dim_b.size()); - // The previous Rank-2 dimensions are accumulated on the batch_count. + // The leading (rank - 2) dimensions are multiplied into batch_count, and + // the last two dimensions are used for matrix multiplication. 
for (int j = 0; j < dim_a.size() - 2; ++j) { - PADDLE_ENFORCE(dim_b[j] == dim_a[j], - "The dimensions of X[%d] and Y[%d] must be the same.", j, - j); + PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j], + "The %d-th dimension of X and Y must be the same.", + j); out_dim.push_back(dim_a[j]); batch_count *= dim_a[j]; } diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc index 1707ed7e7d..d395dfd81b 100644 --- a/paddle/operators/matmul_op.cc +++ b/paddle/operators/matmul_op.cc @@ -45,16 +45,18 @@ class MatMulOp : public framework::OperatorWithKernel { std::vector out_dim; int64_t batch_count = 1; if (dim_x.size() > 3) { - PADDLE_ENFORCE(dim_y.size() == dim_x.size(), - "The dimensions of X and Y must be the same, and both of " - "them should be %d-dimensional.", - dim_x.size()); + PADDLE_ENFORCE_EQ( + dim_y.size(), dim_x.size(), + "The dimensions of X and Y must be the same, and both of " + "them should be %d-dimensional.", + dim_x.size()); - // The previous Rank-2 dimensions are accumulated on the batch_count. + // The leading (rank - 2) dimensions are multiplied into batch_count, and + // the last two dimensions are used for matrix multiplication. for (int j = 0; j < dim_x.size() - 2; ++j) { - PADDLE_ENFORCE(dim_y[j] == dim_x[j], - "The dimensions of X[%d] and Y[%d] must be the same.", j, - j); + PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j], + "The %d-th dimension of X and Y must be the same.", + j); out_dim.push_back(dim_x[j]); batch_count *= dim_x[j]; } @@ -191,10 +193,10 @@ Examples without transpose: The behavior is designed to be similar to the `numpy.matmul` function. The differences are: -- When the rank of the input is greater than 3, the rank of X and - Y must be equal, and the former rank-2 dimensions are equal. - When the rank of the input data is less than or equal to 3, it is similar to the `numpy.matmul` function. +- When the rank of the input is greater than 3, the rank of X and + Y must be equal, and the front `rank - 2` dimensions must be equal. 
- We add `transpose_X` and `transpose_Y` flags. Both the input `X` and `Y` can carry the LoD (Level of Details) information, diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index cf60234295..4935db839c 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -138,7 +138,8 @@ class MatMulGradKernel : public framework::OpKernel { } int batch_count = 0; - // The previous Rank-2 dimensions are accumulated on the batch_count. + // The leading (rank - 2) dimensions are multiplied into batch_count, and + // the last two dimensions are used for matrix multiplication. if (x_dims.size() > 3) { batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, std::multiplies()); -- GitLab