提交 578d60bf，作者：chengduoZH

code refine

上级 2edc136c
...@@ -41,22 +41,18 @@ class MatMulFunctor { ...@@ -41,22 +41,18 @@ class MatMulFunctor {
"Input tensor a must be at least 1-dimensional."); "Input tensor a must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_b.size(), 1, PADDLE_ENFORCE_GE(dim_b.size(), 1,
"Input tensor b must be at least 1-dimensional."); "Input tensor b must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_a.size(), 4,
"Input tensor a must be at most 4-dimensional.");
PADDLE_ENFORCE_LE(dim_b.size(), 4,
"Input tensor b must be at most 4-dimensional.");
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
int64_t batch_count = 1; int64_t batch_count = 1;
if (dim_a.size() > 3) { if (dim_a.size() > 3) {
PADDLE_ENFORCE(dim_b.size() > 3, PADDLE_ENFORCE(dim_b.size() > 3,
"The dimensions of X and Y must be the same, and both of " "The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional."); "them should be %d-dimensional.",
dim_b.size());
for (int j = 0; j < dim_a.size() - 2; ++j) { for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE( PADDLE_ENFORCE(dim_b[j] == dim_a[j],
dim_b[j] == dim_a[j], "The dimensions of X[%d] and Y[%d] must be the same.", j,
"The dimensions of X and Y must be the same, and both of " j);
"them should be 4-dimensional.");
out_dim.push_back(dim_a[j]); out_dim.push_back(dim_a[j]);
batch_count *= dim_a[j]; batch_count *= dim_a[j];
} }
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/matmul_op.h" #include "paddle/operators/matmul_op.h"
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,22 +41,18 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -42,22 +41,18 @@ class MatMulOp : public framework::OperatorWithKernel {
"Input tensor X must be at least 1-dimensional."); "Input tensor X must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_y.size(), 1, PADDLE_ENFORCE_GE(dim_y.size(), 1,
"Input tensor Y must be at least 1-dimensional."); "Input tensor Y must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_x.size(), 4,
"Input tensor X must be at most 4-dimensional.");
PADDLE_ENFORCE_LE(dim_y.size(), 4,
"Input tensor Y must be at most 4-dimensional.");
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
int64_t batch_count = 1; int64_t batch_count = 1;
if (dim_x.size() > 3) { if (dim_x.size() > 3) {
PADDLE_ENFORCE(dim_y.size() == dim_x.size(), PADDLE_ENFORCE(dim_y.size() == dim_x.size(),
"The dimensions of X and Y must be the same, and both of " "The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional."); "them should be %d-dimensional.",
dim_x.size());
for (int j = 0; j < dim_x.size() - 2; ++j) { for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE( PADDLE_ENFORCE(dim_y[j] == dim_x[j],
dim_y[j] == dim_x[j], "The dimensions of X[%d] and Y[%d] must be the same.", j,
"The dimensions of X and Y must be the same, and both of " j);
"them should be 4-dimensional.");
out_dim.push_back(dim_x[j]); out_dim.push_back(dim_x[j]);
batch_count *= dim_x[j]; batch_count *= dim_x[j];
} }
......
...@@ -137,6 +137,12 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -137,6 +137,12 @@ class MatMulGradKernel : public framework::OpKernel<T> {
y_dims.push_back(1); y_dims.push_back(1);
} }
int batch_count = 0;
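// Batch count: product of every dimension of X except the last two (computed only when X has more than 3 dimensions).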
if (x_dims.size() > 3) {
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
std::multiplies<int>());
}
// Fix the dOut dimensions. // Fix the dOut dimensions.
int M = 0, N = 0, batchCountX = 0, batchCountY = 0; int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
...@@ -149,8 +155,7 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -149,8 +155,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
M = transpose_x ? x_dims[2] : x_dims[1]; M = transpose_x ? x_dims[2] : x_dims[1];
break; break;
default: default:
batchCountX = accumulate(x_dims.begin(), x_dims.end() - 2, 1, batchCountX = batch_count;
std::multiplies<int>());
size_t mat_s = x_dims.size() - 2; size_t mat_s = x_dims.size() - 2;
M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s]; M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
} }
...@@ -164,8 +169,7 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -164,8 +169,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
N = transpose_y ? y_dims[1] : y_dims[2]; N = transpose_y ? y_dims[1] : y_dims[2];
break; break;
default: default:
batchCountY = accumulate(y_dims.begin(), y_dims.end() - 2, 1, batchCountY = batch_count;
std::multiplies<int>());
size_t mat_s = y_dims.size() - 2; size_t mat_s = y_dims.size() - 2;
N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1]; N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
} }
...@@ -180,8 +184,6 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -180,8 +184,6 @@ class MatMulGradKernel : public framework::OpKernel<T> {
if (batchCount) { if (batchCount) {
if (x_dims.size() > 3) { if (x_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2); dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
} else if (y_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), y_dims.begin(), y_dims.end() - 2);
} else { } else {
dout_dims.insert(dout_dims.begin(), batchCount); dout_dims.insert(dout_dims.begin(), batchCount);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册