From 259858b41bce74e079503b33cf93ed4e48da9fdb Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 18 Jan 2018 20:42:07 +0800 Subject: [PATCH] modify doc --- paddle/operators/math/matmul.h | 1 + paddle/operators/matmul_op.cc | 8 +++++++- paddle/operators/matmul_op.h | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/paddle/operators/math/matmul.h b/paddle/operators/math/matmul.h index 8a63d204cb..ca41201e12 100644 --- a/paddle/operators/math/matmul.h +++ b/paddle/operators/math/matmul.h @@ -49,6 +49,7 @@ class MatMulFunctor { "The dimensions of X and Y must be the same, and both of " "them should be %d-dimensional.", dim_b.size()); + // The previous Rank-2 dimensions are accumulated on the batch_count. for (int j = 0; j < dim_a.size() - 2; ++j) { PADDLE_ENFORCE(dim_b[j] == dim_a[j], "The dimensions of X[%d] and Y[%d] must be the same.", j, diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc index 6ced0ef6c0..1707ed7e7d 100644 --- a/paddle/operators/matmul_op.cc +++ b/paddle/operators/matmul_op.cc @@ -49,6 +49,8 @@ class MatMulOp : public framework::OperatorWithKernel { "The dimensions of X and Y must be the same, and both of " "them should be %d-dimensional.", dim_x.size()); + + // The previous Rank-2 dimensions are accumulated on the batch_count. for (int j = 0; j < dim_x.size() - 2; ++j) { PADDLE_ENFORCE(dim_y[j] == dim_x[j], "The dimensions of X[%d] and Y[%d] must be the same.", j, @@ -185,10 +187,14 @@ Examples without transpose: - X: [B, M, K], Y: [K] => Out: [B, M] - X: [M, K], Y: [B, K, N] => Out: [B, M, N] - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] +- X: [B, ..., M, K], Y: [B, ..., K, N] => Out: [B, ..., M, N] The behavior is designed to be similar to the `numpy.matmul` function. The differences are: -- Currently only rank 1 to rank 3 input tensors are supported. +- When the rank of the input is greater than 3, the rank of X and + Y must be equal, and the former rank-2 dimensions are equal. +- When the rank of the input data is less than or equal to 3, it + is similar to the `numpy.matmul` function. - We add `transpose_X` and `transpose_Y` flags. Both the input `X` and `Y` can carry the LoD (Level of Details) information, diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 9f06791f7b..cf60234295 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -138,7 +138,7 @@ class MatMulGradKernel : public framework::OpKernel { } int batch_count = 0; - // + // The previous Rank-2 dimensions are accumulated on the batch_count. if (x_dims.size() > 3) { batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, std::multiplies()); -- GitLab