未验证 提交 fee42441 编写于 作者: W wawltor 提交者: GitHub

just add the op error message for the matmul xpu (#30246)

 add the op error message for the matmul xpu 
上级 6bfdef72
...@@ -127,10 +127,18 @@ class MatMulXPUKernel : public framework::OpKernel<T> { ...@@ -127,10 +127,18 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
mat_dim_a.width_, mat_dim_b.height_, mat_dim_a.width_, mat_dim_b.height_,
platform::errors::InvalidArgument("Shape mistake in matmul_op")); platform::errors::InvalidArgument("Shape mistake in matmul_op, the "
PADDLE_ENFORCE_EQ( "first tensor width must be same as "
mat_dim_a.batch_size_, mat_dim_b.batch_size_, "second tensor height, but received "
platform::errors::InvalidArgument("Shape mistake in matmul_op")); "width:%d, height:%d",
mat_dim_a.width_, mat_dim_b.height_));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d",
mat_dim_a.batch_size_, mat_dim_b.batch_size_));
T alpha = static_cast<T>(context.Attr<float>("alpha")); T alpha = static_cast<T>(context.Attr<float>("alpha"));
auto &dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
...@@ -251,12 +259,20 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> { ...@@ -251,12 +259,20 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
} }
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
mat_dim_a.width_, mat_dim_b.height_, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op")); "Shape mistake in matmul_grad_op, the "
PADDLE_ENFORCE_EQ( "first tensor width must be same as second tensor "
mat_dim_a.batch_size_, mat_dim_b.batch_size_, "height, but received "
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op")); "width:%d, height:%d",
mat_dim_a.width_, mat_dim_b.height_));
PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
platform::errors::InvalidArgument(
"Shape mistake in matmul_grad_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d",
mat_dim_a.batch_size_, mat_dim_b.batch_size_));
T alpha = static_cast<T>(context.Attr<float>("alpha")); T alpha = static_cast<T>(context.Attr<float>("alpha"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册