提交 2edc136c 编写于 作者: C chengduoZH

add 4-d for matmul_op

上级 4b3e22b8
...@@ -41,10 +41,26 @@ class MatMulFunctor { ...@@ -41,10 +41,26 @@ class MatMulFunctor {
"Input tensor a must be at least 1-dimensional."); "Input tensor a must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_b.size(), 1, PADDLE_ENFORCE_GE(dim_b.size(), 1,
"Input tensor b must be at least 1-dimensional."); "Input tensor b must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_a.size(), 3, PADDLE_ENFORCE_LE(dim_a.size(), 4,
"Input tensor a must be at most 3-dimensional."); "Input tensor a must be at most 4-dimensional.");
PADDLE_ENFORCE_LE(dim_b.size(), 3, PADDLE_ENFORCE_LE(dim_b.size(), 4,
"Input tensor b must be at most 3-dimensional."); "Input tensor b must be at most 4-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_a.size() > 3) {
PADDLE_ENFORCE(dim_b.size() > 3,
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE(
dim_b[j] == dim_a[j],
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
out_dim.push_back(dim_a[j]);
batch_count *= dim_a[j];
}
}
int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0, int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0,
strideA = 0, strideB = 0; strideA = 0, strideB = 0;
...@@ -67,7 +83,11 @@ class MatMulFunctor { ...@@ -67,7 +83,11 @@ class MatMulFunctor {
strideA = M * kA; strideA = M * kA;
break; break;
default: default:
assert(false); batchCountA = batch_count;
size_t mat_s = dim_a.size() - 2;
M = trans_a ? dim_a[mat_s + 1] : dim_a[mat_s];
kA = trans_a ? dim_a[mat_s] : dim_a[mat_s + 1];
strideA = M * kA;
} }
switch (dim_b.size()) { switch (dim_b.size()) {
...@@ -88,7 +108,11 @@ class MatMulFunctor { ...@@ -88,7 +108,11 @@ class MatMulFunctor {
strideB = kB * N; strideB = kB * N;
break; break;
default: default:
assert(false); batchCountB = batch_count;
size_t mat_s = dim_b.size() - 2;
kB = trans_b ? dim_b[mat_s + 1] : dim_b[mat_s];
N = trans_b ? dim_b[mat_s] : dim_b[mat_s + 1];
strideB = kB * N;
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/matmul_op.h" #include "paddle/operators/matmul_op.h"
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -41,10 +42,26 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -41,10 +42,26 @@ class MatMulOp : public framework::OperatorWithKernel {
"Input tensor X must be at least 1-dimensional."); "Input tensor X must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_y.size(), 1, PADDLE_ENFORCE_GE(dim_y.size(), 1,
"Input tensor Y must be at least 1-dimensional."); "Input tensor Y must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_x.size(), 3, PADDLE_ENFORCE_LE(dim_x.size(), 4,
"Input tensor X must be at most 3-dimensional."); "Input tensor X must be at most 4-dimensional.");
PADDLE_ENFORCE_LE(dim_y.size(), 3, PADDLE_ENFORCE_LE(dim_y.size(), 4,
"Input tensor Y must be at most 3-dimensional."); "Input tensor Y must be at most 4-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_x.size() > 3) {
PADDLE_ENFORCE(dim_y.size() == dim_x.size(),
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE(
dim_y[j] == dim_x[j],
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
out_dim.push_back(dim_x[j]);
batch_count *= dim_x[j];
}
}
int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0; int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
bool remove_initial_dim = false, remove_final_dim = false; bool remove_initial_dim = false, remove_final_dim = false;
...@@ -70,7 +87,11 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -70,7 +87,11 @@ class MatMulOp : public framework::OperatorWithKernel {
KX = transpose_x ? dim_x[1] : dim_x[2]; KX = transpose_x ? dim_x[1] : dim_x[2];
break; break;
default: default:
assert(false); batchCountX = batch_count;
size_t mat_s = dim_x.size() - 2;
M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
break;
} }
switch (dim_y.size()) { switch (dim_y.size()) {
...@@ -94,7 +115,10 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -94,7 +115,10 @@ class MatMulOp : public framework::OperatorWithKernel {
N = transpose_y ? dim_y[1] : dim_y[2]; N = transpose_y ? dim_y[1] : dim_y[2];
break; break;
default: default:
assert(false); batchCountY = batch_count;
size_t mat_s = dim_y.size() - 2;
KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -110,7 +134,11 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -110,7 +134,11 @@ class MatMulOp : public framework::OperatorWithKernel {
std::vector<int64_t> dim_out; std::vector<int64_t> dim_out;
if (batchCount) { if (batchCount) {
dim_out.push_back(batchCount); if (dim_x.size() > 3) {
dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
} else {
dim_out.push_back(batchCount);
}
} }
if (!remove_initial_dim) { if (!remove_initial_dim) {
dim_out.push_back(M); dim_out.push_back(M);
......
...@@ -149,7 +149,10 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -149,7 +149,10 @@ class MatMulGradKernel : public framework::OpKernel<T> {
M = transpose_x ? x_dims[2] : x_dims[1]; M = transpose_x ? x_dims[2] : x_dims[1];
break; break;
default: default:
assert(false); batchCountX = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
std::multiplies<int>());
size_t mat_s = x_dims.size() - 2;
M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
} }
switch (y_dims.size()) { switch (y_dims.size()) {
...@@ -161,7 +164,10 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -161,7 +164,10 @@ class MatMulGradKernel : public framework::OpKernel<T> {
N = transpose_y ? y_dims[1] : y_dims[2]; N = transpose_y ? y_dims[1] : y_dims[2];
break; break;
default: default:
assert(false); batchCountY = accumulate(y_dims.begin(), y_dims.end() - 2, 1,
std::multiplies<int>());
size_t mat_s = y_dims.size() - 2;
N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
} }
if (batchCountX && batchCountY) { if (batchCountX && batchCountY) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -172,7 +178,13 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -172,7 +178,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
int batchCount = std::max(batchCountX, batchCountY); int batchCount = std::max(batchCountX, batchCountY);
std::vector<int64_t> dout_dims = {M, N}; std::vector<int64_t> dout_dims = {M, N};
if (batchCount) { if (batchCount) {
dout_dims.insert(dout_dims.begin(), batchCount); if (x_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
} else if (y_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), y_dims.begin(), y_dims.end() - 2);
} else {
dout_dims.insert(dout_dims.begin(), batchCount);
}
} }
Tensor X = Reshape<T>(x, make_ddim(x_dims)); Tensor X = Reshape<T>(x, make_ddim(x_dims));
Tensor Y = Reshape<T>(y, make_ddim(y_dims)); Tensor Y = Reshape<T>(y, make_ddim(y_dims));
......
...@@ -128,5 +128,22 @@ for dim_X in [1, 2, 3]: ...@@ -128,5 +128,22 @@ for dim_X in [1, 2, 3]:
}) })
globals()[test_name] = test_class globals()[test_name] = test_class
# Test case 4-dim
dim_X = 4
dim_Y = 4
transpose_X = False
transpose_Y = False
test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
dim_X, dim_Y, transpose_X, transpose_Y))
shape_X = [2, 2, 2, 3]
shape_Y = [2, 2, 3, 4]
test_class = type(test_name, (Generator, OpTest), {
'shape_X': shape_X,
'shape_Y': shape_Y,
'transpose_X': transpose_X,
'transpose_Y': transpose_Y,
})
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册