未验证 提交 f26f7c36 编写于 作者: W wawltor 提交者: GitHub

Add some error meesage and dtyp, dtyep check for some ops (#23762)

Those ops include,scale, sum, sums,unique_with_counts,unique,
wherre, add error message and test case
上级 b822f74c
......@@ -28,16 +28,16 @@ class ScaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ScaleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ScaleOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "scale");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "scale");
if (ctx->IsRuntime() && ctx->HasInput("ScaleTensor")) {
auto scale = ctx->Inputs("ScaleTensor");
PADDLE_ENFORCE_EQ(scale.size(), 1,
platform::errors::InvalidArgument(
"Input(ScaleTensor) size must be 1"));
"Input(ScaleTensor) size must be 1, "
"but received size is %d.",
scale.size()));
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
......
......@@ -32,11 +32,9 @@ class SumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
"Inputs(X) should not be null");
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "sum");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sum");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR_ARRAY) {
......@@ -48,11 +46,11 @@ class SumOp : public framework::OperatorWithKernel {
auto N = x_dims.size();
PADDLE_ENFORCE_GT(
N, 0,
"ShapeError: The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s].",
N, &x_dims);
N, 0, platform::errors::InvalidArgument(
"The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s].",
N, &x_dims));
if (N == 1) {
VLOG(3) << "Warning: SumOp have only one input, may waste memory";
}
......@@ -72,18 +70,21 @@ class SumOp : public framework::OperatorWithKernel {
in_dim = x_dim;
} else {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
in_dim, x_dim,
"ShapeError: The input tensor X of SumOp must have same shape."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim);
PADDLE_ENFORCE_EQ(in_dim, x_dim,
platform::errors::InvalidArgument(
"The input tensor X of SumOp must"
" have same shape. But received X[0]'s shape = "
"[%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim));
} else {
PADDLE_ENFORCE_EQ(
in_dim.size(), x_dim.size(),
"ShapeError: The input tensor X of SumOp must have same "
"dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
"[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
platform::errors::InvalidArgument(
"The input tensor X of SumOp must have same "
"dimensions. But received X[0]'s dimensions = %d, X[0]'s "
"shape = "
"[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
in_dim.size(), in_dim, i, x_dim.size(), i, x_dim));
// if in_dim or x_dim has -1, not check equal
for (int j = 0; j < x_dim.size(); ++j) {
if (x_dim[j] == -1 || in_dim[j] == -1) {
......@@ -91,10 +92,11 @@ class SumOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_EQ(
in_dim[j], x_dim[j],
"ShapeError: The input tensor X of SumOp must have same shape "
"if not -1."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim);
platform::errors::InvalidArgument(
"The input tensor X of SumOp must have same shape "
"if not -1."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim));
}
}
}
......@@ -115,9 +117,10 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1;
for (size_t idx = 0; idx < x_vars.size(); ++idx) {
PADDLE_ENFORCE_NOT_NULL(x_vars[idx],
"Input var[%s] should not be nullptr",
x_vars_name[idx]);
PADDLE_ENFORCE_NOT_NULL(
x_vars[idx],
platform::errors::NotFound("Input var[%s] should not be nullptr",
x_vars_name[idx]));
auto tensor =
framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
if (tensor->numel() <= 0 || (!tensor->IsInitialized())) {
......@@ -126,11 +129,14 @@ class SumOp : public framework::OperatorWithKernel {
if (dtype == -1) {
dtype = tensor->type();
} else {
PADDLE_ENFORCE_EQ(dtype, tensor->type());
PADDLE_ENFORCE_EQ(dtype, tensor->type(),
platform::errors::InvalidArgument(
"The inputs type of sum op must be same"));
}
}
PADDLE_ENFORCE_NE(dtype, -1,
"Sum operator should have at least one tensor");
platform::errors::InvalidArgument(
"Sum operator should have at least one tensor"));
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
......
......@@ -93,7 +93,10 @@ void LodTensorArrayCompute(const framework::ExecutionContext &context) {
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE_EQ(in_vars[i]->IsType<framework::LoDTensorArray>(), true,
"Only support all inputs are TensorArray");
platform::errors::InvalidArgument(
"Only support all inputs are TensorArray, "
"but inputs[%d] is not TensorArray.",
i));
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) {
......@@ -106,7 +109,12 @@ void LodTensorArrayCompute(const framework::ExecutionContext &context) {
context.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE_EQ(out_array[i].lod(), in_array[i].lod());
PADDLE_ENFORCE_EQ(
out_array[i].lod(), in_array[i].lod(),
platform::errors::InvalidArgument(
"The lod message between inputs[%d] and"
" outputs[%d] must be same, but now is not same.",
i, i));
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
result.device(*context.template device_context<DeviceContext>()
......
......@@ -22,15 +22,16 @@ class UniqueOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UniqueOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniqueOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Index"),
"Output(Index) of UniqueOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "unique");
OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique");
auto in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(in_dims.size() == 1, "Input(X) should be a vector.");
PADDLE_ENFORCE_EQ(
in_dims.size(), 1,
platform::errors::InvalidArgument("The Input(X) should be 1-D Tensor, "
"But now the dims of Input(X) is %d.",
in_dims.size()));
ctx->SetOutputDim("Out", {-1});
ctx->SetOutputDim("Index", in_dims);
......
......@@ -46,8 +46,12 @@ struct UniqueOpFunctor {
std::unordered_map<InT, int64_t> dict;
std::vector<InT> uniq;
PADDLE_ENFORCE(in_->numel() < pow(2, 31),
"numel of Unique op input should less than INT_MAX");
PADDLE_ENFORCE_LT(
in_->numel(), pow(2, 31),
platform::errors::InvalidArgument(
"The num of Input(X) elements should be less then INT_MAX, "
"but received num is %d.",
in_->numel()));
for (auto i = 0; i < in_->numel(); i++) {
auto it = dict.find(in_data[i]);
......@@ -71,13 +75,15 @@ struct UniqueOpFunctor {
const auto& index_type = index_->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds %s, "
"but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
for (auto i = 0; i < in_->numel(); ++i) {
......
......@@ -22,19 +22,20 @@ class UniqueWithCountsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UniqueWithCountsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniqueWithCountsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Index"),
"Output(Index) of UniqueWithCountsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Count"),
"Output(Count) of UniqueWithCountsOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique_with_counts");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"unique_with_counts");
OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index",
"unique_with_counts");
OP_INOUT_CHECK(ctx->HasOutput("Count"), "Output", "Count",
"unique_with_counts");
auto in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(in_dims.size() == 1,
"The op of fluid.layers.unique_with_counts, Input(X) should "
"be a vector.");
PADDLE_ENFORCE_EQ(
in_dims.size(), 1,
platform::errors::InvalidArgument("The Input(X) should be 1-D Tensor, "
"But now the dims of Input(X) is %d.",
in_dims.size()));
ctx->SetOutputDim("Out", {-1});
ctx->SetOutputDim("Index", in_dims);
......
......@@ -22,18 +22,12 @@ class WhereIndexOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Condition"), true,
platform::errors::NotFound(
"Input(Condition) of layers.where should not be null."));
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "where");
PADDLE_ENFORCE_GE(
ctx->GetInputDim("Condition").size(), 1UL,
platform::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of layers.where should not be null."));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "where");
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
}
......
......@@ -10708,6 +10708,10 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"""
check_variable_and_dtype(
x, "x",
['float32', 'float64', 'uint8', 'int16', 'int32', 'in64', 'uint8'],
"scale")
if in_dygraph_mode():
_scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale
out = core.ops.scale(x, 'scale',
......@@ -13256,6 +13260,7 @@ def where(condition):
out = layers.where(condition) # [[]]
"""
check_variable_and_dtype(condition, "condition", ['bool'], "where")
helper = LayerHelper("where_index", **locals())
out = helper.create_variable_for_type_inference(
......@@ -13324,6 +13329,8 @@ def unique(x, dtype='int32'):
out, index = fluid.layers.unique(x) # out is [2, 3, 1, 5]; index is [0, 1, 1, 2, 3, 1]
"""
check_variable_and_dtype(x, "x", ['float32', 'float64', 'int32', 'int64'],
"unique")
helper = LayerHelper("unique", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -13368,6 +13375,8 @@ def unique_with_counts(x, dtype='int32'):
# count is [1, 3, 1, 1]
# x.shape=(6,) out.shape=(4,), index.shape=(6,), count.shape=(4,)
"""
check_variable_and_dtype(x, "x", ['float32', 'float64', 'int32', 'int64'],
"unique_with_counts")
if not (dtype == 'int32' or dtype == 'int64'):
raise TypeError(
"Op unique_with_counts, index dtype must be int32 or int64")
......
......@@ -464,10 +464,23 @@ def sums(input, out=None):
# Sum of multiple Tensors, sum1 and x3 represents the same Variable (x3=x0+x1+x2, the value is [[6, ..., 6], ..., [6, ..., 6]])
sum1 = fluid.layers.sums(input=[x0, x1, x2], out=x3)
"""
check_type(input, 'input', (Variable, tuple, list), 'sums')
if isinstance(input, list) or isinstance(input, tuple):
for input_section in input:
check_variable_and_dtype(input_section, "input", \
['float32', 'float64', 'int32', 'int64'], 'sums')
else:
check_variable_and_dtype(input, "input", \
['float32', 'float64', 'int32', 'int64'], 'sums')
helper = LayerHelper('sum', **locals())
if out is None:
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
else:
check_variable_and_dtype(
out, "out", ['float32', 'float64', 'int32', 'int64'], 'sums')
helper.append_op(
type='sum',
inputs={'X': input},
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
......@@ -123,6 +124,20 @@ class TestScaleOpSelectedRows(unittest.TestCase):
self.check_with_place(place, 'in', 'in')
class TestScaleRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.scale([10])
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float16", name="input")
fluid.layers.scale(data)
self.assertRaises(TypeError, test_dtype)
# Add FP16 test
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
......
......@@ -241,6 +241,63 @@ class API_Test_Elementwise_Sum(unittest.TestCase):
self.assertEqual((result == expected_result).all(), True)
class TestRaiseSumError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.sum([11, 22])
self.assertRaises(TypeError, test_type)
def test_dtype():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
data2 = fluid.data(name="input2", shape=[10], dtype="int8")
fluid.layers.sum([data1, data2])
self.assertRaises(TypeError, test_dtype)
def test_dtype1():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
fluid.layers.sum(data1)
self.assertRaises(TypeError, test_dtype1)
class TestRaiseSumsError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.sums([11, 22])
self.assertRaises(TypeError, test_type)
def test_dtype():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
data2 = fluid.data(name="input2", shape=[10], dtype="int8")
fluid.layers.sums([data1, data2])
self.assertRaises(TypeError, test_dtype)
def test_dtype1():
data1 = fluid.data(name="input1", shape=[10], dtype="int8")
fluid.layers.sums(data1)
self.assertRaises(TypeError, test_dtype1)
def test_out_type():
data1 = fluid.data(name="input1", shape=[10], dtype="flaot32")
data2 = fluid.data(name="input2", shape=[10], dtype="float32")
fluid.layers.sums([data1, data2], out=[10])
self.assertRaises(TypeError, test_out_type)
def test_out_dtype():
data1 = fluid.data(name="input1", shape=[10], dtype="flaot32")
data2 = fluid.data(name="input2", shape=[10], dtype="float32")
out = fluid.data(name="out", shape=[10], dtype="int8")
fluid.layers.sums([data1, data2], out=out)
self.assertRaises(TypeError, test_out_dtype)
create_test_sum_fp16_class(TestSelectedRowsSumOp)
create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp)
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
......@@ -68,6 +69,20 @@ class TestRandom(TestUniqueOp):
self.outputs = {'Out': target_out, 'Index': target_index}
class TestUniqueRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.unique([10])
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float16", name="input")
fluid.layers.unique(data)
self.assertRaises(TypeError, test_dtype)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestOneGPU(TestUniqueOp):
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
......@@ -80,6 +81,20 @@ class TestRandom(TestUniqueWithCountsOp):
}
class TestUniqueWithCountsRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.unique_with_counts([10])
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float16", name="input")
fluid.layers.unique_with_counts(data)
self.assertRaises(TypeError, test_dtype)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestOneGPU(TestUniqueWithCountsOp):
......
......@@ -102,5 +102,19 @@ class TestWhereOpError(unittest.TestCase):
out = exe.run(fluid.default_main_program(), feed={'cond': cond_i})
class TestWhereRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.where([10])
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float32", name="input")
fluid.layers.where(data)
self.assertRaises(TypeError, test_dtype)
if __name__ == "__main__":
unittest.main()
......@@ -825,6 +825,17 @@ def elementwise_sum(inputs, name=None):
"""
helper = LayerHelper('elementwise_sum', **locals())
check_type(inputs, 'inputs', (Variable, tuple, list), 'elementwise_sum')
if isinstance(inputs, list) or isinstance(inputs, tuple):
if len(inputs) > 0:
for input in inputs:
check_variable_and_dtype(input, "inputs", \
['float32', 'float64', 'int32', 'int64'], 'elementwise_sum')
else:
check_variable_and_dtype(inputs, "inputs", \
['float32', 'float64', 'int32', 'int64'], 'elementwise_sum')
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('inputs'))
helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册