Unverified commit b76a6dee authored by wangchaochaohu, committed by GitHub

Cherry-pick: fix error messages (#20164)

* fix the error message for reduce_mean and reduce_sum op (#20063)

* fix the error message for reduce_mean and reduce_sum op test=develop

* fix typo test=develop

* fix according to review advice test=develop

* fix the test test=develop

* fix test=develop

* Fill constant error message fix (#20075)

* fix the constant error message test=develop

* fix typo test=develop

* fix typo test=develop

* fix code style test=develop

* fix comment and bugs test=develop

* fix the bug test=develop

* fix and add unittest test=develop

* fix the typo test=develop

* add support for the fill_constant op test=develop

* add test for ci coverage test=develop
Parent 907a853d
@@ -87,4 +87,5 @@ REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>);
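With the `bool` kernel registered above, `fill_constant` can now materialize boolean tensors on CPU. A minimal sketch of the Python-side call, assuming a Fluid 1.x build that includes this patch:

```python
import paddle.fluid as fluid

# Sketch only: the newly registered bool kernel lets fill_constant
# create a boolean tensor directly.
mask = fluid.layers.fill_constant(shape=[2, 3], dtype='bool', value=True)
```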
@@ -165,13 +165,20 @@ class ReduceOp : public framework::OperatorWithKernel {
"Output(Out) of ReduceOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
PADDLE_ENFORCE_LE(x_rank, 6,
                  "ShapeError: The rank of the input tensor X of Reduce "
                  "should be no more than 6, but received X's "
                  "rank = %d, X's shape = [%s].",
                  x_rank, x_dims);
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i], x_rank,
                  "ShapeError: The reduce dim index %d should be in the "
                  "range [-dimension(X), dimension(X)), where "
                  "dimension(X) = %d, but received dim index = %d.",
                  i, x_rank, dims[i]);
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
@@ -202,7 +209,7 @@ class ReduceOp : public framework::OperatorWithKernel {
}
auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims);
if (dims[0] != 0) {
if (dims.size() > 0 && dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
ctx->ShareLoD("X", /*->*/ "Out");
}
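The added `dims.size() > 0` guard matters when reducing over all axes: the `dim` attribute can arrive empty, and unconditionally reading `dims[0]` was an out-of-bounds access. A pure-Python sketch of the guarded condition:

```python
# Sketch of the LoD-sharing condition in ReduceOp::InferShape: share LoD
# only when not reducing over the first dimension, and never index an
# empty dims list (the out-of-bounds read this hunk fixes).
def should_share_lod(dims):
    return len(dims) > 0 and dims[0] != 0

assert should_share_lod([1, 2])
assert not should_share_lod([0, 1])
assert not should_share_lod([])  # previously read past the end
```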
@@ -223,10 +230,12 @@ class ReduceGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
for (size_t i = 0; i < dims.size(); ++i) {
PADDLE_ENFORCE_LT(dims[i], x_rank,
                  "ShapeError: The reduce dim index %d should be in the "
                  "range [-dimension(X), dimension(X)), where "
                  "dimension(X) = %d, but received dim index = %d.",
                  i, x_rank, dims[i]);
if (dims[i] < 0) dims[i] = x_rank + dims[i];
PADDLE_ENFORCE_LT(
dims[i], x_rank,
"The dim should be in the range [-rank(input), rank(input)).");
}
sort(dims.begin(), dims.end());
auto x_grad_name = framework::GradVarName("X");
......
@@ -27,7 +27,19 @@ __all__ = ['DataFeeder']
def convert_dtype(dtype):
if dtype == core.VarDesc.VarType.FP32:
if isinstance(dtype, str):
if dtype in [
'float32', 'int64', 'float64', 'float16', 'int32', 'uint8',
'bool'
]:
return dtype
else:
raise ValueError(
    "dtype must be any of [bool, float16, int32, float32, "
    "int64, float64, uint8]")
elif dtype == core.VarDesc.VarType.BOOL:
return 'bool'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
@@ -40,7 +52,7 @@ def convert_dtype(dtype):
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
else:
raise ValueError("dtype must be any of [int32, float32, int64, "
raise ValueError("dtype must be any of [bool,int32, float32, int64, "
"float64, uint8]")
......
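`convert_dtype` now validates plain string dtype names in addition to the `VarDesc.VarType` enums it already handled. A quick sketch of both paths, assuming a Fluid 1.x install:

```python
import paddle.fluid.core as core
from paddle.fluid.data_feeder import convert_dtype

# String names pass through after validation...
assert convert_dtype('float32') == 'float32'
# ...and enum values still map to their string names.
assert convert_dtype(core.VarDesc.VarType.BOOL) == 'bool'
# An unsupported name now raises ValueError instead of falling through:
# convert_dtype('uint4')  # -> ValueError
```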
@@ -5611,6 +5611,15 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
"""
helper = LayerHelper('reduce_sum', **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in reduce_sum must be Variable, but received %s"
% (type(input)))
if convert_dtype(
input.dtype) not in ['float32', 'float64', 'int32', 'int64']:
raise TypeError(
    "The data type of 'input' in reduce_sum must be float32, float64, "
    "int32 or int64, but received %s." % convert_dtype(input.dtype))
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
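With these checks, `reduce_sum` fails fast on a non-Variable input or an unsupported dtype instead of erroring deep inside the operator; `reduce_mean` below gains the same guards. A hedged sketch of both failure modes, assuming Fluid 1.x:

```python
import numpy as np
import paddle.fluid as fluid

x_np = np.ones((2, 2), dtype='float32')
# A raw ndarray is not a Variable -> TypeError from the first check:
# fluid.layers.reduce_sum(x_np)

x_u8 = fluid.layers.data(name='x_u8', shape=[4], dtype='uint8')
# uint8 is outside {float32, float64, int32, int64} -> TypeError:
# fluid.layers.reduce_sum(x_u8)
```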
@@ -5670,6 +5679,15 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_mean(y, dim=[0, 1]) # [4.0, 5.0]
"""
helper = LayerHelper('reduce_mean', **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in reduce_mean must be Variable, but received %s"
% (type(input)))
if convert_dtype(
input.dtype) not in ['float32', 'float64', 'int32', 'int64']:
raise TypeError(
    "The data type of 'input' in reduce_mean must be float32, float64, "
    "int32 or int64, but received %s." % convert_dtype(input.dtype))
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list):
dim = [dim]
......
@@ -21,6 +21,7 @@ from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu
from ..core import VarDesc
from .layer_function_generator import templatedoc
from ..data_feeder import convert_dtype
import numpy
__all__ = [
@@ -397,8 +398,21 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
"""
helper = LayerHelper("fill_constant", **locals())
if convert_dtype(dtype) not in [
        'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
]:
    raise TypeError(
        "The data type requested in fill_constant must be one of bool, "
        "float16, float32, float64, int32 or int64, but received %s." %
        convert_dtype(dtype))
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
    if convert_dtype(dtype) != convert_dtype(out.dtype):
        raise TypeError(
            "The data type requested in fill_constant must be the same as "
            "the dtype of 'out', but received %s and out dtype %s." %
            (convert_dtype(dtype), convert_dtype(out.dtype)))
helper.append_op(
type='fill_constant',
inputs={},
......
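The Python-side checks mirror the kernel registrations above: the requested dtype must be one of the supported types, and when `out` is supplied its dtype must match. A small sketch of the mismatch case, assuming a Fluid 1.x program context (`buf` is a hypothetical name):

```python
import paddle.fluid as fluid

buf = fluid.layers.create_tensor(dtype='int32', name='buf')
# A mismatched dtype/out pair now raises TypeError up front:
# fluid.layers.fill_constant(shape=[1], value=5, dtype='float64', out=buf)
ok = fluid.layers.fill_constant(shape=[1], value=5, dtype='int32', out=buf)
```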
@@ -183,7 +183,7 @@ class TestIfElse(unittest.TestCase):
false_target = fluid.layers.tanh(false_target)
ie.output(false_target)
if_out = ie()
out = layers.reduce_sum(if_out)
out = layers.reduce_sum(if_out[0])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......
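Calling an `IfElse` block returns a list of merged output Variables, one per `ie.output(...)` call, and `reduce_sum` now type-checks its input, so the test must select the first output explicitly. Sketch:

```python
# ie() yields a list like [Variable]; index before reducing, since
# reduce_sum rejects non-Variable inputs such as a list.
if_out = ie()
out = layers.reduce_sum(if_out[0])
```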
@@ -20,6 +20,8 @@ from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestFillConstantOp1(OpTest):
@@ -104,5 +106,41 @@ class TestFillConstantOpWithSelectedRows(OpTest):
self.check_with_place(place)
class TestFillConstantOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# for CI coverage
x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
self.assertRaises(
ValueError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='uint4')
self.assertRaises(
ValueError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='int16',
out=x1)
# The input dtype of fill_constant must be one of bool, float16,
# float32, float64, int32 or int64.
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='uint8')
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='float64',
out=x2)
if __name__ == "__main__":
unittest.main()
@@ -17,6 +17,9 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestSumOp(OpTest):
@@ -411,5 +414,29 @@ class Test1DReduceWithAxes1(OpTest):
self.check_grad(['X'], 'Out')
class TestReduceSumOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of reduce_sum_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.reduce_sum, x1)
# The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)
class TestReduceMeanOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of reduce_mean_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.reduce_mean, x1)
# The input dtype of reduce_mean_op must be float32 or float64 or int32 or int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.reduce_mean, x2)
if __name__ == '__main__':
unittest.main()