提交 578d60bf 编写于 作者: C chengduoZH

code refine

上级 2edc136c
......@@ -41,22 +41,18 @@ class MatMulFunctor {
"Input tensor a must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_b.size(), 1,
"Input tensor b must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_a.size(), 4,
"Input tensor a must be at most 4-dimensional.");
PADDLE_ENFORCE_LE(dim_b.size(), 4,
"Input tensor b must be at most 4-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_a.size() > 3) {
PADDLE_ENFORCE(dim_b.size() > 3,
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
"them should be %d-dimensional.",
dim_b.size());
for (int j = 0; j < dim_a.size() - 2; ++j) {
PADDLE_ENFORCE(
dim_b[j] == dim_a[j],
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
PADDLE_ENFORCE(dim_b[j] == dim_a[j],
"The dimensions of X[%d] and Y[%d] must be the same.", j,
j);
out_dim.push_back(dim_a[j]);
batch_count *= dim_a[j];
}
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/matmul_op.h"
#include <vector>
namespace paddle {
namespace operators {
......@@ -42,22 +41,18 @@ class MatMulOp : public framework::OperatorWithKernel {
"Input tensor X must be at least 1-dimensional.");
PADDLE_ENFORCE_GE(dim_y.size(), 1,
"Input tensor Y must be at least 1-dimensional.");
PADDLE_ENFORCE_LE(dim_x.size(), 4,
"Input tensor X must be at most 4-dimensional.");
PADDLE_ENFORCE_LE(dim_y.size(), 4,
"Input tensor Y must be at most 4-dimensional.");
std::vector<int64_t> out_dim;
int64_t batch_count = 1;
if (dim_x.size() > 3) {
PADDLE_ENFORCE(dim_y.size() == dim_x.size(),
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
"them should be %d-dimensional.",
dim_x.size());
for (int j = 0; j < dim_x.size() - 2; ++j) {
PADDLE_ENFORCE(
dim_y[j] == dim_x[j],
"The dimensions of X and Y must be the same, and both of "
"them should be 4-dimensional.");
PADDLE_ENFORCE(dim_y[j] == dim_x[j],
"The dimensions of X[%d] and Y[%d] must be the same.", j,
j);
out_dim.push_back(dim_x[j]);
batch_count *= dim_x[j];
}
......
......@@ -137,6 +137,12 @@ class MatMulGradKernel : public framework::OpKernel<T> {
y_dims.push_back(1);
}
int batch_count = 0;
//
if (x_dims.size() > 3) {
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
std::multiplies<int>());
}
// Fix the dOut dimensions.
int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
......@@ -149,8 +155,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
M = transpose_x ? x_dims[2] : x_dims[1];
break;
default:
batchCountX = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
std::multiplies<int>());
batchCountX = batch_count;
size_t mat_s = x_dims.size() - 2;
M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
}
......@@ -164,8 +169,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
N = transpose_y ? y_dims[1] : y_dims[2];
break;
default:
batchCountY = accumulate(y_dims.begin(), y_dims.end() - 2, 1,
std::multiplies<int>());
batchCountY = batch_count;
size_t mat_s = y_dims.size() - 2;
N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
}
......@@ -180,8 +184,6 @@ class MatMulGradKernel : public framework::OpKernel<T> {
if (batchCount) {
if (x_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
} else if (y_dims.size() > 3) {
dout_dims.insert(dout_dims.begin(), y_dims.begin(), y_dims.end() - 2);
} else {
dout_dims.insert(dout_dims.begin(), batchCount);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册