未验证 提交 fee42441 编写于 作者: W wawltor 提交者: GitHub

just add the op error message for the matmul xpu (#30246)

 add the op error message for the matmul xpu 
上级 6bfdef72
......@@ -127,10 +127,18 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
mat_dim_a.width_, mat_dim_b.height_,
platform::errors::InvalidArgument("Shape mistake in matmul_op"));
PADDLE_ENFORCE_EQ(
mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument("Shape mistake in matmul_op"));
platform::errors::InvalidArgument("Shape mistake in matmul_op, the "
"first tensor width must be same as "
"second tensor height, but received "
"width:%d, height:%d",
mat_dim_a.width_, mat_dim_b.height_));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d",
mat_dim_a.batch_size_, mat_dim_b.batch_size_));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto &dev_ctx = context.template device_context<DeviceContext>();
......@@ -251,12 +259,20 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
}
}
PADDLE_ENFORCE_EQ(
mat_dim_a.width_, mat_dim_b.height_,
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
PADDLE_ENFORCE_EQ(
mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_grad_op, the "
"first tensor width must be same as second tensor "
"height, but received "
"width:%d, height:%d",
mat_dim_a.width_, mat_dim_b.height_));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_grad_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d",
mat_dim_a.batch_size_, mat_dim_b.batch_size_));
T alpha = static_cast<T>(context.Attr<float>("alpha"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册