提交 a4753f3a 编写于 作者: 石晓伟 提交者: Tao Luo

Optimize error message of mean_op and matmul_op (#20413)

* add data type check, test=develop

* polish error messages, test=develop

* polish error messages, test=develop

* Remove support for the CPU architecture matmul, test=develop

* fix syntax bug, test=develop
上级 d6c1d6ca
......@@ -21,6 +21,17 @@ limitations under the License. */
namespace paddle {
namespace operators {
/**
* Printing shape information into a string is easy to use.
*/
inline static std::string DumpMatrixShape(const math::MatDescriptor &desc) {
std::stringstream buffer;
buffer << "[" << desc.batch_size_ << ", " << desc.height_ << ", "
<< desc.width_ << "]";
return buffer.str();
}
/**
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
......@@ -303,21 +314,37 @@ class MatMulOp : public framework::OperatorWithKernel {
context->Attrs().Get<bool>("transpose_Y"));
if (context->IsRuntime()) {
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
PADDLE_ENFORCE(
mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0,
"ShapeError: The batch size of the two matrices should be equal, or "
"at least one is zero.\n"
"But received X's shape: %s, Y's shape: %s.",
DumpMatrixShape(mat_dim_x).c_str(),
DumpMatrixShape(mat_dim_y).c_str());
}
std::vector<int64_t> dim_out;
int64_t dim_out_y = mat_dim_y.width_;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
int head_number = context->Attrs().Get<int>("head_number");
bool split_vertical_y = (mat_dim_x.width_ != mat_dim_y.height_);
PADDLE_ENFORCE_LE(head_number, mat_dim_x.width_);
PADDLE_ENFORCE_LE(
head_number, mat_dim_x.width_,
"ShapeError: Unsatisfied mkl acceleration library requirements: "
"The number of heads "
"(%d) must be equal to X's width. But received X's shape: %s.",
head_number, DumpMatrixShape(mat_dim_x).c_str());
if (!split_vertical_y && head_number > 0) {
dim_out_y = head_number * mat_dim_y.width_;
}
#else
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
PADDLE_ENFORCE_EQ(
mat_dim_x.width_, mat_dim_y.height_,
"ShapeError: Input X's width should be equal to the Y's height, "
"but received X's shape: %s,"
"Y's shape: %s.",
DumpMatrixShape(mat_dim_x).c_str(), DumpMatrixShape(mat_dim_y).c_str());
#endif
if (mat_dim_x.batch_size_ != 0) {
......@@ -461,15 +488,11 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
REGISTER_OP_CPU_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
matmul_grad,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>);
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL(
......
......@@ -6873,6 +6873,22 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
"""
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
if not isinstance(val, Variable):
raise TypeError(
"The type of %s in matmul must be Variable, but received %s.\n"
% (name, (type(val))))
if convert_dtype(val.dtype) in ['float16']:
warnings.warn(
"The data type of %s in matmul only support float16 in GPU now."
% name)
if convert_dtype(
val.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of %s in matmul must be float16 or float32 or float64, but received %s.\n"
% (name, (convert_dtype(val.dtype))))
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
......@@ -6886,8 +6902,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
if transpose_y:
y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2]
if x_shape[-1] != y_shape[-2]:
raise ValueError("Invalid inputs for matmul. x: %s, y: %s\n" %
(x_shape, y_shape))
raise ValueError(
"After performing an optional transpose, Input X's width should be "
"equal to Y's width for multiplication "
"prerequisites. But received X's shape: %s, Y's shape: %s\n" %
(x_shape, y_shape))
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
......@@ -6895,8 +6914,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" %
(x.shape, y.shape))
raise ValueError(
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape))
__check_input(x, y)
......@@ -14711,6 +14733,20 @@ def mean(x, name=None):
helper = LayerHelper("mean", **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in mean must be Variable, but received %s.\n" %
(type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in mean only support float16 in GPU now.")
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in mean must be float16 or float32 or float64, but received %s.\n"
% (convert_dtype(x.dtype)))
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
......
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y):
......@@ -112,6 +114,21 @@ class Generator(object):
['X'], 'Out', max_relative_error=1e-3, no_grad_set=set('Y'))
class TestMatmulOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The inputs type of matmul_op must be Variable.
input1 = 12
self.assertRaises(TypeError, fluid.layers.matmul, input1, input1)
# The inputs dtype of matmul_op must be float32, float64.
input2 = fluid.layers.data(
name='input2', shape=[10, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.matmul, input2, input2)
input3 = fluid.layers.data(
name='input3', shape=[2, 2], dtype="float16")
fluid.layers.matmul(input3, input3)
# Generate test cases for all possibilities
def inject_test(dim_x, dim_y, trans_x, trans_y):
test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
......
......@@ -18,6 +18,8 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestMeanOp(OpTest):
......@@ -38,6 +40,21 @@ class TestMeanOp(OpTest):
self.check_grad(['X'], 'Out')
class TestMeanOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of mean_op must be Variable.
input1 = 12
self.assertRaises(TypeError, fluid.layers.mean, input1)
# The input dtype of mean_op must be float16, float32, float64.
input2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.mean, input2)
input3 = fluid.layers.data(
name='input3', shape=[4], dtype="float16")
fluid.layers.softmax(input3)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16MeanOp(TestMeanOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册