未验证 提交 b822f74c 编写于 作者: W wawltor 提交者: GitHub

Add the error raise for some operators, add some test cases

Add the error raise for those cases
aassign isfinite linspace ones_like zeros_like zeros ones
上级 fb34bdb4
...@@ -75,10 +75,10 @@ class AssignKernel { ...@@ -75,10 +75,10 @@ class AssignKernel {
if (x == nullptr) { if (x == nullptr) {
return; return;
} }
PADDLE_ENFORCE_EQ(
ctx.HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of assign_op is not found."));
auto *out = ctx.OutputVar("Out"); auto *out = ctx.OutputVar("Out");
PADDLE_ENFORCE(
out != nullptr,
"The Output(Out) should not be null if the Input(X) is set.");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace()); auto &dev_ctx = *pool.Get(ctx.GetPlace());
......
...@@ -28,8 +28,9 @@ class AssignValueOp : public framework::OperatorWithKernel { ...@@ -28,8 +28,9 @@ class AssignValueOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_EQ(
"Output(Out) of AssignValueOp should not be null."); ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of assign_op is not found."));
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape)); ctx->SetOutputDim("Out", framework::make_ddim(shape));
} }
......
...@@ -23,10 +23,8 @@ class FillAnyLikeOp : public framework::OperatorWithKernel { ...@@ -23,10 +23,8 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fill_any_like");
"Input(X) of FillAnyLikeOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fill_any_like");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillAnyLikeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
......
...@@ -40,15 +40,20 @@ class FillAnyLikeKernel : public framework::OpKernel<T> { ...@@ -40,15 +40,20 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
auto common_type_value = static_cast<CommonType>(value); auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
(common_type_value >= (common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) && static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <= (common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())), static_cast<CommonType>(std::numeric_limits<T>::max())),
"filled value is out of range for targeted type in fill_any_like " true, platform::errors::InvalidArgument(
"kernel"); "filled value is out of range for"
" targeted type in fill_any_like, your kernel type is %s"
PADDLE_ENFORCE(!std::isnan(value), "filled value is NaN"); ", please check value you set.",
typeid(T).name()));
PADDLE_ENFORCE_EQ(
std::isnan(value), false,
platform::errors::InvalidArgument("filled value should not be NaN,"
" but received NaN"));
math::SetConstant<DeviceContext, T> setter; math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), out, setter(context.template device_context<DeviceContext>(), out,
......
...@@ -22,10 +22,8 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { ...@@ -22,10 +22,8 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fill_zeros_like");
"Input(X) of FillZerosLikeOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fill_zeros_like");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillZerosLikeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
......
...@@ -24,25 +24,29 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -24,25 +24,29 @@ class LinspaceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Start"), PADDLE_ENFORCE(ctx->HasInput("Start"),
"Input(Start) of LinspaceOp should not be null."); "Input(Start) of LinspaceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Stop"), OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
"Input(Stop) of LinspaceOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
PADDLE_ENFORCE(ctx->HasInput("Num"), OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
"Input(Num) of LinspaceOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "linspace");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(OUt) of LinspaceOp should not be null.");
auto s_dims = ctx->GetInputDim("Start"); auto s_dims = ctx->GetInputDim("Start");
PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1), PADDLE_ENFORCE_EQ((s_dims.size() == 1) && (s_dims[0] == 1), true,
"The shape of Input(Start) should be [1]."); platform::errors::InvalidArgument(
"The shape of Input(Start) must be [1],"
"but received input shape is [%s].",
s_dims));
auto e_dims = ctx->GetInputDim("Stop"); auto e_dims = ctx->GetInputDim("Stop");
PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1), PADDLE_ENFORCE_EQ((e_dims.size() == 1) && (e_dims[0] == 1), true,
"The shape of Input(Stop) should be [1]."); platform::errors::InvalidArgument(
"The shape of Input(Stop) must be [1],"
"but received input shape is [%s].",
e_dims));
auto step_dims = ctx->GetInputDim("Num"); auto step_dims = ctx->GetInputDim("Num");
PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1), PADDLE_ENFORCE_EQ(
"The shape of Input(Num) should be [1]."); (step_dims.size() == 1) && (step_dims[0] == 1), true,
platform::errors::InvalidArgument("The shape of Input(Num) must be [1],"
"but received input shape is [%s].",
step_dims));
ctx->SetOutputDim("Out", {-1}); ctx->SetOutputDim("Out", {-1});
} }
......
...@@ -961,8 +961,10 @@ def ones(shape, dtype, force_cpu=False): ...@@ -961,8 +961,10 @@ def ones(shape, dtype, force_cpu=False):
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.ones(shape=[2, 4], dtype='float32') # [[1., 1., 1., 1.], [1., 1., 1., 1.]] data = fluid.layers.ones(shape=[2, 4], dtype='float32') # [[1., 1., 1., 1.], [1., 1., 1., 1.]]
""" """
assert isinstance(shape, list) or isinstance( check_type(shape, 'shape', (list, tuple), 'ones')
shape, tuple), "The shape's type should be list or tuple." check_dtype(dtype, 'create data type',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'ones')
assert reduce(lambda x, y: x * y, assert reduce(lambda x, y: x * y,
shape) > 0, "The shape is invalid: %s." % (str(shape)) shape) > 0, "The shape is invalid: %s." % (str(shape))
return fill_constant(value=1.0, **locals()) return fill_constant(value=1.0, **locals())
...@@ -990,6 +992,7 @@ def zeros(shape, dtype, force_cpu=False): ...@@ -990,6 +992,7 @@ def zeros(shape, dtype, force_cpu=False):
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]] data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
""" """
check_type(shape, 'shape', (list, tuple), 'zeros')
check_dtype(dtype, 'create data type', check_dtype(dtype, 'create data type',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'zeros') 'zeros')
...@@ -1174,7 +1177,10 @@ def isfinite(x): ...@@ -1174,7 +1177,10 @@ def isfinite(x):
dtype="float32") dtype="float32")
out = fluid.layers.isfinite(var) out = fluid.layers.isfinite(var)
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"isfinite")
helper = LayerHelper("isfinite", **locals()) helper = LayerHelper("isfinite", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out}) helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out})
return out return out
...@@ -1273,12 +1279,25 @@ def linspace(start, stop, num, dtype): ...@@ -1273,12 +1279,25 @@ def linspace(start, stop, num, dtype):
""" """
helper = LayerHelper("linspace", **locals()) helper = LayerHelper("linspace", **locals())
check_type(start, 'start', (Variable, float, int), linspace)
check_type(stop, 'stop', (Variable, float, int), linspace)
check_type(num, 'num', (Variable, float, int), linspace)
if not isinstance(start, Variable): if not isinstance(start, Variable):
start = fill_constant([1], dtype, start) start = fill_constant([1], dtype, start)
else:
check_variable_and_dtype(start, "start", ["float32", "float64"],
"linspace")
if not isinstance(stop, Variable): if not isinstance(stop, Variable):
stop = fill_constant([1], dtype, stop) stop = fill_constant([1], dtype, stop)
else:
check_variable_and_dtype(stop, "stop", ["float32", "float64"],
"linspace")
if not isinstance(num, Variable): if not isinstance(num, Variable):
num = fill_constant([1], 'int32', num) num = fill_constant([1], 'int32', num)
else:
check_variable_and_dtype(num, "num", ["int32"], "linspace")
out = helper.create_variable_for_type_inference(dtype=start.dtype) out = helper.create_variable_for_type_inference(dtype=start.dtype)
...@@ -1315,9 +1334,16 @@ def zeros_like(x, out=None): ...@@ -1315,9 +1334,16 @@ def zeros_like(x, out=None):
""" """
check_variable_and_dtype(
x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'ones_like')
helper = LayerHelper("zeros_like", **locals()) helper = LayerHelper("zeros_like", **locals())
if out is None: if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like')
helper.append_op( helper.append_op(
type='fill_zeros_like', inputs={'X': [x]}, outputs={'Out': [out]}) type='fill_zeros_like', inputs={'X': [x]}, outputs={'Out': [out]})
out.stop_gradient = True out.stop_gradient = True
...@@ -1462,10 +1488,16 @@ def ones_like(x, out=None): ...@@ -1462,10 +1488,16 @@ def ones_like(x, out=None):
data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0] data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0]
""" """
check_variable_and_dtype(
x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'ones_like')
helper = LayerHelper("ones_like", **locals()) helper = LayerHelper("ones_like", **locals())
if out is None: if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like')
helper.append_op( helper.append_op(
type='fill_any_like', type='fill_any_like',
inputs={'X': [x]}, inputs={'X': [x]},
......
...@@ -247,6 +247,34 @@ class TestOnesZerosError(unittest.TestCase): ...@@ -247,6 +247,34 @@ class TestOnesZerosError(unittest.TestCase):
self.assertRaises(ValueError, test_device_error4) self.assertRaises(ValueError, test_device_error4)
def test_ones_like_type_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
fluid.layers.ones_like([10], dtype="float")
self.assertRaises(TypeError, test_ones_like_type_error)
def test_ones_like_dtype_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float16")
fluid.layers.ones_like(data, dtype="float32")
self.assertRaises(TypeError, test_ones_like_dtype_error)
def test_ones_like_out_type_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
fluid.layers.ones_like(data, dtype="float32", out=[10])
self.assertRaises(TypeError, test_ones_like_out_type_error)
def test_ones_like_out_dtype_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
out = fluid.data(name="out", shape=[10], dtype="float16")
fluid.layers.ones_like(data, dtype="float32", out=out)
self.assertRaises(TypeError, test_ones_like_out_dtype_error)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -457,6 +457,30 @@ class ApiOnesZerosError(unittest.TestCase): ...@@ -457,6 +457,30 @@ class ApiOnesZerosError(unittest.TestCase):
self.assertRaises(ValueError, test_error2) self.assertRaises(ValueError, test_error2)
def test_error3():
with fluid.program_guard(fluid.Program()):
ones = fluid.layers.ones(shape=10, dtype="int64")
self.assertRaises(TypeError, test_error3)
def test_error4():
with fluid.program_guard(fluid.Program()):
ones = fluid.layers.ones(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error4)
def test_error5():
with fluid.program_guard(fluid.Program()):
ones = fluid.layers.zeros(shape=10, dtype="int64")
self.assertRaises(TypeError, test_error5)
def test_error6():
with fluid.program_guard(fluid.Program()):
ones = fluid.layers.zeros(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error6)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from op_test import OpTest from op_test import OpTest
...@@ -46,5 +47,36 @@ class TestFillZerosLike2OpFp64(TestFillZerosLike2Op): ...@@ -46,5 +47,36 @@ class TestFillZerosLike2OpFp64(TestFillZerosLike2Op):
self.dtype = np.float64 self.dtype = np.float64
class TestZerosError(unittest.TestCase):
def test_errors(self):
def test_zeros_like_type_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
fluid.layers.zeros_like([10], dtype="float")
self.assertRaises(TypeError, test_zeros_like_type_error)
def test_zeros_like_dtype_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float16")
fluid.layers.zeros_like(data, dtype="float32")
self.assertRaises(TypeError, test_zeros_like_dtype_error)
def test_zeros_like_out_type_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
fluid.layers.zeros_like(data, dtype="float32", out=[10])
self.assertRaises(TypeError, test_zeros_like_out_type_error)
def test_zeros_like_out_dtype_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
out = fluid.data(name="out", shape=[10], dtype="float16")
fluid.layers.zeros_like(data, dtype="float32", out=out)
self.assertRaises(TypeError, test_zeros_like_out_dtype_error)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
import unittest
class TestInf(OpTest): class TestInf(OpTest):
...@@ -38,6 +40,20 @@ class TestInf(OpTest): ...@@ -38,6 +40,20 @@ class TestInf(OpTest):
self.check_output() self.check_output()
class TestRaiseError(unittest.TestCase):
def test_errors(self):
def test_type():
fluid.layers.isfinite([10])
self.assertRaises(TypeError, test_type)
def test_dtype():
data = fluid.data(shape=[10], dtype="float16", name="input")
fluid.layers.isfinite(data)
self.assertRaises(TypeError, test_dtype)
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestFP16Inf(TestInf): class TestFP16Inf(TestInf):
......
...@@ -99,6 +99,39 @@ class TestLinspaceOpError(unittest.TestCase): ...@@ -99,6 +99,39 @@ class TestLinspaceOpError(unittest.TestCase):
self.assertRaises(ValueError, test_device_value) self.assertRaises(ValueError, test_device_value)
def test_start_type():
fluid.layers.linspace([0], 10, 1, dtype="float32")
self.assertRaises(TypeError, test_start_type)
def test_end_dtype():
fluid.layers.linspace(0, [10], 1, dtype="float32")
self.assertRaises(TypeError, test_end_dtype)
def test_step_dtype():
fluid.layers.linspace(0, 10, [0], dtype="float32")
self.assertRaises(TypeError, test_step_dtype)
def test_start_dtype():
start = fluid.data(shape=[1], type="int32", name="start")
fluid.layers.linspace(start, 10, 1, dtype="float32")
self.assertRaises(TypeError, test_start_dtype)
def test_end_dtype():
end = fluid.data(shape=[1], type="int32", name="end")
fluid.layers.linspace(0, end, 1, dtype="float32")
self.assertRaises(TypeError, test_end_dtype)
def test_step_dtype():
step = fluid.data(shape=[1], type="int32", name="step")
fluid.layers.linspace(0, 10, step, dtype="float32")
self.assertRaises(TypeError, test_step_dtype)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册