Commit 782ddc5f authored by chengduoZH

follow comments

Parent cd38e2d1
@@ -49,7 +49,7 @@ class MatMulFunctor {
                       "The dimensions of X and Y must be the same, and both of "
                       "them should be %d-dimensional.",
                       dim_b.size());
-    // The front rank-2 dimensions are accumulated on the batch_count, and the
+    // The first rank-2 dimensions are accumulated on the batch_count, and the
     // last two dimensions are used for matrix multiplication.
     for (int j = 0; j < dim_a.size() - 2; ++j) {
       PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j],
......
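The comment being fixed here describes how inputs of rank greater than 3 are handled: all leading dimensions (the first `rank - 2` of them) are folded into a single batch count, and only the last two dimensions take part in the matrix multiply. A minimal NumPy sketch of that folding, purely for illustration (the function name and shapes below are made up, not part of the patch):

```python
import numpy as np

def folded_batch_matmul(x, y):
    # Illustrative only: fold the first rank-2 dims into one batch dimension,
    # then multiply the trailing (M, K) x (K, N) matrices batch by batch.
    assert x.ndim == y.ndim and x.shape[:-2] == y.shape[:-2]
    batch_count = int(np.prod(x.shape[:-2]))
    m, k = x.shape[-2:]
    k2, n = y.shape[-2:]
    assert k == k2
    out = np.matmul(x.reshape(batch_count, m, k), y.reshape(batch_count, k, n))
    return out.reshape(x.shape[:-2] + (m, n))

x = np.random.rand(2, 3, 4, 5)
y = np.random.rand(2, 3, 5, 6)
np.testing.assert_allclose(folded_batch_matmul(x, y), np.matmul(x, y), rtol=1e-6)
```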
@@ -51,7 +51,7 @@ class MatMulOp : public framework::OperatorWithKernel {
                      "them should be %d-dimensional.",
                      dim_x.size());
-    // The front rank-2 dimensions are accumulated on the batch_count, and the
+    // The first rank-2 dimensions are accumulated on the batch_count, and the
     // last two dimensions are used for matrix multiplication.
     for (int j = 0; j < dim_x.size() - 2; ++j) {
       PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
@@ -196,7 +196,7 @@ The differences are:
 - When the rank of the input data is less than or equal to 3, it
   is similar to the `numpy.matmul` function.
 - When the rank of the input is greater than 3, the rank of X and
-  Y must be equal, and the front `rank - 2` dimensions must be equal.
+  Y must be equal, and the first `rank - 2` dimensions must be equal.
 - We add `transpose_X` and `transpose_Y` flags.
 Both the input `X` and `Y` can carry the LoD (Level of Details) information,
......
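The docstring fragment above spells out the operator's contract: inputs of rank 3 or less behave like `numpy.matmul`, inputs of rank greater than 3 must have equal ranks and equal first `rank - 2` dimensions, and `transpose_X` / `transpose_Y` transpose the last two dimensions of the corresponding input. A rough NumPy reference of that documented behaviour (a sketch of the contract, not the operator's implementation):

```python
import numpy as np

def matmul_reference(x, y, transpose_X=False, transpose_Y=False):
    # Sketch of the documented contract: optionally swap the last two axes
    # of each input, then fall back to numpy.matmul semantics.
    if x.ndim > 3:
        assert x.ndim == y.ndim and x.shape[:-2] == y.shape[:-2]
    if transpose_X and x.ndim >= 2:
        x = np.swapaxes(x, -1, -2)
    if transpose_Y and y.ndim >= 2:
        y = np.swapaxes(y, -1, -2)
    return np.matmul(x, y)

a = np.random.rand(3, 2, 5)   # rank 3: behaves like numpy.matmul
b = np.random.rand(3, 4, 5)
print(matmul_reference(a, b, transpose_Y=True).shape)  # (3, 2, 4)
```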
@@ -138,7 +138,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
     }
     int batch_count = 0;
-    // The front rank-2 dimensions are accumulated on the batch_count, and the
+    // The first rank-2 dimensions are accumulated on the batch_count, and the
     // last two dimensions are used for matrix multiplication.
     if (x_dims.size() > 3) {
       batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
......
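In the gradient kernel the same convention appears concretely: `batch_count` is computed with `accumulate` over `x_dims.begin() .. x_dims.end() - 2`, i.e. over everything except the last two dimensions (the binary operation is cut off in this hunk, but the comment implies a product). The equivalent arithmetic in Python, just to make it explicit:

```python
from functools import reduce
import operator

def batch_count(dims):
    # Product of the first rank-2 dimensions; the trailing two dims are the
    # matrix dims. Mirrors the accumulate(...) call shown in the hunk,
    # assuming the elided binary op is a multiplication.
    return reduce(operator.mul, dims[:-2], 1)

print(batch_count([2, 3, 4, 5]))  # 6: the trailing 4x5 matrices repeat across 6 batches
```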
@@ -127,6 +127,7 @@ for dim_X in [1, 2, 3]:
         })
 
+# Test case n-dim
 def generate_compatible_shapes(dim, transpose_X, transpose_Y):
     M = 2
     N = 4
@@ -135,14 +136,14 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y):
     shape_Y = [2 for _ in range(dim - 2)]
     if transpose_X:
-        shape_X = shape_X + [K, M]
+        shape_X += [K, M]
     else:
-        shape_X = shape_X + [M, K]
+        shape_X += [M, K]
     if transpose_Y:
-        shape_Y = shape_Y + [N, K]
+        shape_Y += [N, K]
     else:
-        shape_Y = shape_Y + [K, N]
+        shape_Y += [K, N]
     return shape_X, shape_Y
......
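The test-side change is cosmetic (`shape_X = shape_X + [...]` becomes `shape_X += [...]`), but the helper is worth reading as a whole: it builds n-dimensional shapes by prepending `dim - 2` batch dimensions of size 2 and then appending the matrix dimensions, swapped whenever the matching transpose flag is set. A standalone sketch of how it behaves (K is not visible in this hunk, so the value 3 below is assumed):

```python
def generate_compatible_shapes(dim, transpose_X, transpose_Y):
    # Leading batch dims of size 2, then the (M, K) x (K, N) matrix dims,
    # swapped when the corresponding transpose flag is set.
    M, N, K = 2, 4, 3  # M and N match the hunk; K = 3 is an assumed value
    shape_X = [2 for _ in range(dim - 2)]
    shape_Y = [2 for _ in range(dim - 2)]
    shape_X += [K, M] if transpose_X else [M, K]
    shape_Y += [N, K] if transpose_Y else [K, N]
    return shape_X, shape_Y

print(generate_compatible_shapes(4, False, True))
# ([2, 2, 2, 3], [2, 2, 4, 3]) -> X is (2, 2, M, K), Y is (2, 2, N, K) before transposing
```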