未验证 提交 0fb18bc2 编写于 作者: S ShenLiang 提交者: GitHub

enforce the matmul_v2 error message (#29297)

上级 9b59a589
......@@ -71,8 +71,14 @@ static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims,
for (int i = 0; i < ndim; ++i) {
PADDLE_ENFORCE_EQ(
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1,
true, platform::errors::InvalidArgument(
"Input(X) and Input(Y) has error dim."));
true,
platform::errors::InvalidArgument(
"Input(X) and Input(Y) has error dim."
"X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s],"
"or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1,"
"But received X_broadcast's shape[%s] = [%s]"
"received Y_broadcast's shape[%s] = [%s]",
i, i, i, i, i, x_bd_dims[i], i, y_bd_dims[i]));
if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) {
out_bd_dims[i] = 0;
} else {
......@@ -118,10 +124,13 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const T* y_data = Y->data<T>();
if (x_ndim == 1 && y_ndim == 1) {
PADDLE_ENFORCE_EQ(X->numel(), Y->numel(),
platform::errors::InvalidArgument(
"X's numbers is not equal to Y's numbers,"
"when X/Y's dims =1"));
PADDLE_ENFORCE_EQ(
X->numel(), Y->numel(),
platform::errors::InvalidArgument(
"X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements",
X->numel(), Y->numel()));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->mutable_data<T>(ctx.GetPlace());
......@@ -140,13 +149,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if (x_ndim == 1) {
const int N = X->numel();
if (trans_y) {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 1], N,
platform::errors::InvalidArgument("Input(Y) has error dim."));
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], N,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, N, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 2], N,
platform::errors::InvalidArgument("Input(Y) has error dim."));
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], N,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, N, y_ndim - 2, y_dims[y_ndim - 2]));
}
std::vector<std::int64_t> out_dims(y_ndim - 1);
if (trans_y) {
......@@ -182,13 +197,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if (y_ndim == 1) {
const int N = Y->numel();
if (trans_x) {
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 2], N,
platform::errors::InvalidArgument("Input(X) has error dim."));
PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], N,
platform::errors::InvalidArgument(
"Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d",
x_ndim - 2, N, x_ndim - 2, x_dims[x_ndim - 2]));
} else {
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 1], N,
platform::errors::InvalidArgument("Input(X) has error dim."));
PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], N,
platform::errors::InvalidArgument(
"Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d",
x_ndim - 1, N, x_ndim - 1, x_dims[x_ndim - 1]));
}
std::vector<std::int64_t> out_dims(x_ndim - 1);
if (trans_x) {
......@@ -225,11 +246,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, platform::errors::InvalidArgument(
"Input(X) has error dim."));
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, platform::errors::InvalidArgument(
"Input(X) has error dim."));
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
}
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
const int ndim = (std::max)(x_ndim, y_ndim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册