提交 0468422d 编写于 作者: C chengduoZH

follow comments

上级 95b896ce
...@@ -45,15 +45,16 @@ class MatMulFunctor { ...@@ -45,15 +45,16 @@ class MatMulFunctor {
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
int64_t batch_count = 1; int64_t batch_count = 1;
if (dim_a.size() > 3) { if (dim_a.size() > 3) {
PADDLE_ENFORCE(dim_b.size() > 3, PADDLE_ENFORCE(dim_b.size() == dim_a.size(),
"The dimensions of X and Y must be the same, and both of " "The dimensions of X and Y must be the same, and both of "
"them should be %d-dimensional.", "them should be %d-dimensional.",
dim_b.size()); dim_b.size());
// The previous Rank-2 dimensions are accumulated on the batch_count. // The front rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_a.size() - 2; ++j) { for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE(dim_b[j] == dim_a[j], PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j],
"The dimensions of X[%d] and Y[%d] must be the same.", j, "The %d-th dimension of X and Y must be the same.",
j); j);
out_dim.push_back(dim_a[j]); out_dim.push_back(dim_a[j]);
batch_count *= dim_a[j]; batch_count *= dim_a[j];
} }
......
...@@ -45,16 +45,18 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -45,16 +45,18 @@ class MatMulOp : public framework::OperatorWithKernel {
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
int64_t batch_count = 1; int64_t batch_count = 1;
if (dim_x.size() > 3) { if (dim_x.size() > 3) {
PADDLE_ENFORCE(dim_y.size() == dim_x.size(), PADDLE_ENFORCE_EQ(
"The dimensions of X and Y must be the same, and both of " dim_y.size(), dim_x.size(),
"them should be %d-dimensional.", "The dimensions of X and Y must be the same, and both of "
dim_x.size()); "them should be %d-dimensional.",
dim_x.size());
// The previous Rank-2 dimensions are accumulated on the batch_count. // The front rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_x.size() - 2; ++j) { for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE(dim_y[j] == dim_x[j], PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
"The dimensions of X[%d] and Y[%d] must be the same.", j, "The %d-th dimension of X and Y must be the same.",
j); j);
out_dim.push_back(dim_x[j]); out_dim.push_back(dim_x[j]);
batch_count *= dim_x[j]; batch_count *= dim_x[j];
} }
...@@ -191,10 +193,10 @@ Examples without transpose: ...@@ -191,10 +193,10 @@ Examples without transpose:
The behavior is designed to be similar to the `numpy.matmul` function. The behavior is designed to be similar to the `numpy.matmul` function.
The differences are: The differences are:
- When the rank of the input is greater than 3, the rank of X and
Y must be equal, and the former rank-2 dimensions are equal.
- When the rank of the input data is less than or equal to 3, it - When the rank of the input data is less than or equal to 3, it
is similar to the `numpy.matmul` function. is similar to the `numpy.matmul` function.
- When the rank of the input is greater than 3, the rank of X and
Y must be equal, and the front `rank - 2` dimensions must be equal.
- We add `transpose_X` and `transpose_Y` flags. - We add `transpose_X` and `transpose_Y` flags.
Both the input `X` and `Y` can carry the LoD (Level of Details) information, Both the input `X` and `Y` can carry the LoD (Level of Details) information,
......
...@@ -138,7 +138,8 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -138,7 +138,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} }
int batch_count = 0; int batch_count = 0;
// The previous Rank-2 dimensions are accumulated on the batch_count. // The front rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
if (x_dims.size() > 3) { if (x_dims.size() > 3) {
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
std::multiplies<int>()); std::multiplies<int>());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册