Commit 782ddc5f, author: chengduoZH

follow comments

Parent: cd38e2d1
......@@ -49,7 +49,7 @@ class MatMulFunctor {
"The dimensions of X and Y must be the same, and both of "
"them should be %d-dimensional.",
dim_b.size());
- // The front rank-2 dimensions are accumulated on the batch_count, and the
+ // The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE_EQ(dim_b[j], dim_a[j],
......
......@@ -51,7 +51,7 @@ class MatMulOp : public framework::OperatorWithKernel {
"them should be %d-dimensional.",
dim_x.size());
- // The front rank-2 dimensions are accumulated on the batch_count, and the
+ // The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
......@@ -196,7 +196,7 @@ The differences are:
- When the rank of the input data is less than or equal to 3, it
is similar to the `numpy.matmul` function.
- When the rank of the input is greater than 3, the rank of X and
-    Y must be equal, and the front `rank - 2` dimensions must be equal.
+    Y must be equal, and the first `rank - 2` dimensions must be equal.
- We add `transpose_X` and `transpose_Y` flags.
Both the input `X` and `Y` can carry the LoD (Level of Details) information,
......
......@@ -138,7 +138,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
}
int batch_count = 0;
- // The front rank-2 dimensions are accumulated on the batch_count, and the
+ // The first rank-2 dimensions are accumulated on the batch_count, and the
// last two dimensions are used for matrix multiplication.
if (x_dims.size() > 3) {
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
......
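Note on the comment being reworded in the hunks above: the "first rank-2 dimensions" rule can be sketched in a few lines of Python. This is only an illustration of the comment, not the kernel's code. `batch_count` here is a hypothetical helper, and returning 1 for inputs of rank 2 or lower is an assumption (the C++ shown above only covers the `x_dims.size() > 3` branch).

```python
from functools import reduce
from operator import mul


def batch_count(dims):
    """Multiply together every dimension except the last two.

    Hypothetical helper: the last two dimensions are treated as the
    matrix dimensions, and everything before them is folded into one
    batch size.
    """
    if len(dims) <= 2:
        # Assumption: a plain matrix (or vector) has no batch dimensions.
        return 1
    return reduce(mul, dims[:-2], 1)


# A rank-4 input of shape [2, 3, 4, 5] is handled as 2 * 3 = 6
# independent 4x5 matrices.
print(batch_count([2, 3, 4, 5]))
```

For a rank-3 input the slice `dims[:-2]` holds a single element, so the result is simply the leading dimension.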
......@@ -127,6 +127,7 @@ for dim_X in [1, 2, 3]:
})
# Test case n-dim
def generate_compatible_shapes(dim, transpose_X, transpose_Y):
M = 2
N = 4
......@@ -135,14 +136,14 @@ def generate_compatible_shapes(dim, transpose_X, transpose_Y):
     shape_Y = [2 for _ in range(dim - 2)]
     if transpose_X:
-        shape_X = shape_X + [K, M]
+        shape_X += [K, M]
     else:
-        shape_X = shape_X + [M, K]
+        shape_X += [M, K]
     if transpose_Y:
-        shape_Y = shape_Y + [N, K]
+        shape_Y += [N, K]
     else:
-        shape_Y = shape_Y + [K, N]
+        shape_Y += [K, N]
     return shape_X, shape_Y
......
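For readers who want to try the n-dim shape helper outside the test file, here is a self-contained sketch in its post-change form. The diff elides the lines that define `K` and `shape_X`, so `K = 3` and the initial value of `shape_X` are assumptions. `output_shape` is a hypothetical checker added only for illustration. It applies the rule from the operator documentation: equal rank, equal first `rank - 2` dimensions, and matching inner dimensions after the transpose flags are applied.

```python
def generate_compatible_shapes(dim, transpose_X, transpose_Y):
    M = 2
    N = 4
    K = 3  # assumed value; the diff does not show how K is defined
    # Assumed to mirror shape_Y; the diff elides this line as well.
    shape_X = [2 for _ in range(dim - 2)]
    shape_Y = [2 for _ in range(dim - 2)]
    shape_X += [K, M] if transpose_X else [M, K]
    shape_Y += [N, K] if transpose_Y else [K, N]
    return shape_X, shape_Y


def output_shape(shape_X, shape_Y, transpose_X, transpose_Y):
    """Hypothetical checker: validate the shapes and return the result shape."""
    x, y = list(shape_X), list(shape_Y)
    if transpose_X:
        x[-2], x[-1] = x[-1], x[-2]
    if transpose_Y:
        y[-2], y[-1] = y[-1], y[-2]
    if len(x) != len(y) or x[:-2] != y[:-2]:
        raise ValueError("rank and the first rank - 2 dimensions must match")
    if x[-1] != y[-2]:
        raise ValueError("inner matrix dimensions must match")
    return x[:-2] + [x[-2], y[-1]]


for tx in (False, True):
    for ty in (False, True):
        sx, sy = generate_compatible_shapes(4, tx, ty)
        print(tx, ty, sx, sy, output_shape(sx, sy, tx, ty))
```

Because `shape_X` and `shape_Y` are fresh local lists, switching from `shape_X = shape_X + [...]` to `shape_X += [...]` does not change the result; it only extends the list in place instead of building a new one.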