Unverified commit 9549b786, authored by Liufang Sang, committed by GitHub

OP Normal, Uniform, Xavier Initializer, smooth_l1, mean_iou error message enhancement (#23751)

* enhance error message test=develop

* enhance error message test=develop

* change to INOUT_CHECK  test=develop
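In short, the pattern applied throughout this commit replaces bare PADDLE_ENFORCE(cond, "message") calls with two forms: OP_INOUT_CHECK for input/output presence checks, and typed comparisons (PADDLE_ENFORCE_EQ/GT/GE/LT) whose message is wrapped in a platform::errors category such as InvalidArgument or NotFound. Below is a minimal sketch of the resulting InferShape style, written against a hypothetical FooOp for illustration only (FooOp is not part of this commit):

    void InferShape(framework::InferShapeContext* ctx) const override {
      // Presence checks: the macro builds the "not found" message from the
      // role ("Input"/"Output"), the variable name, and the op name.
      OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Foo");
      OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Foo");

      // Value checks: a comparison macro plus an explicit error category,
      // so the message reports the value and shape actually received.
      auto x_dims = ctx->GetInputDim("X");
      PADDLE_ENFORCE_GE(x_dims.size(), 2,
                        platform::errors::InvalidArgument(
                            "The rank of Input(X) of FooOp must be at least 2, "
                            "but received rank = %d, shape = [%s].",
                            x_dims.size(), x_dims));
      ctx->SetOutputDim("Out", x_dims);
    }

On the Python side, the commit adds check_variable_and_dtype calls so that invalid inputs fail with a TypeError before the op is appended, as the initializer, layer, and unit-test changes below show.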
Parent 840ac2b3
@@ -50,16 +50,18 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of GaussianRandomOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GaussianRandom");
     auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
     std::vector<int64_t> temp;
     temp.reserve(shape.size());
     for (auto dim : shape) {
       temp.push_back(static_cast<int64_t>(dim));
     }
-    PADDLE_ENFORCE(shape.size() > 0UL,
-                   "shape can be one int or array. shape must be set.");
+    PADDLE_ENFORCE_GT(shape.size(), 0UL,
+                      platform::errors::InvalidArgument(
+                          "Attribute(shape) of GaussianRandomOp must be set "
+                          "and shape.size() > 0."));
     ctx->SetOutputDim("Out", framework::make_ddim(temp));
   }
......
@@ -22,16 +22,14 @@ class MeanIoUOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Predictions"),
-                   "Input (Predictions) of MeanIoU op should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Labels"),
-                   "Input (labels) of MeanIoU op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutMeanIou"),
-                   "Output (OutMeanIou) of MeanIoU op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutWrong"),
-                   "Output (OutWrong) of MeanIoU op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutCorrect"),
-                   "Output (OutWrong) of MeanIoU op should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Predictions"), "Input", "Predictions",
+                   "MeanIoU");
+    OP_INOUT_CHECK(ctx->HasInput("Labels"), "Input", "Labels", "MeanIoU");
+    OP_INOUT_CHECK(ctx->HasOutput("OutMeanIou"), "Output", "OutMeanIou",
+                   "MeanIoU");
+    OP_INOUT_CHECK(ctx->HasOutput("OutWrong"), "Output", "OutWrong", "MeanIoU");
+    OP_INOUT_CHECK(ctx->HasOutput("OutCorrect"), "Output", "OutCorrect",
+                   "MeanIoU");
 
     int64_t num_classes =
         static_cast<int64_t>(ctx->Attrs().Get<int>("num_classes"));
......
@@ -22,18 +22,32 @@ class AccuracyOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Out"),
-                   "Input (Out) of accuracy op should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Indices"),
-                   "Input (Indices) of accuracy op should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"),
-                   "Input (Label) of accuracy op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
-                   "Output (Accuracy) of AccuracyOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Correct"),
-                   "Output (Correct) of AccuracyOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Total"),
-                   "Output (Total) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Out"), true,
+        platform::errors::NotFound("Input (Out) of AccuracyOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true,
+                      platform::errors::NotFound(
+                          "Input (Indices) of AccuracyOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      platform::errors::NotFound(
+                          "Input (Label) of AccuracyOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Accuracy"), true,
+                      platform::errors::NotFound(
+                          "Output (Accuracy) of AccuracyOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Correct"), true,
+                      platform::errors::NotFound(
+                          "Output (Correct) of AccuracyOp is not found."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Total"), true,
+                      platform::errors::NotFound(
+                          "Output (Total) of AccuracyOp is not found."));
+
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "Accuracy");
+    OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Accuracy");
+    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "Accuracy");
+    OP_INOUT_CHECK(ctx->HasOutput("Accuracy"), "Output", "Accuracy",
+                   "Accuracy");
+    OP_INOUT_CHECK(ctx->HasOutput("Correct"), "Output", "Correct", "Accuracy");
+    OP_INOUT_CHECK(ctx->HasOutput("Total"), "Output", "Total", "Accuracy");
 
     auto inference_dim = ctx->GetInputDim("Out");
     auto label_dim = ctx->GetInputDim("Label");
@@ -42,22 +56,26 @@ class AccuracyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         label_dim.size(), 2,
-        "ShapeError: label's dimensions of AccuracyOp must be 2. "
-        "But received label's dimensions = %d, label's shape = [%s]",
-        label_dim.size(), label_dim);
+        platform::errors::InvalidArgument(
+            "ShapeError: label's dimensions of AccuracyOp must be 2. "
+            "But received label's dimensions = %d, label's shape = [%s]",
+            label_dim.size(), label_dim));
     if (ctx->IsRuntime()) {
       PADDLE_ENFORCE_EQ(label_dim[1], 1,
-                        "ShapeError: label's second dimension of "
-                        "AccuracyOp must be 1. But received label's "
-                        "second dimension is = %d, label's shape = [%s]",
-                        label_dim[1], label_dim);
+                        platform::errors::InvalidArgument(
+                            "ShapeError: label's second dimension of "
+                            "AccuracyOp must be 1. But received label's "
+                            "second dimension is = %d, label's shape = [%s]",
+                            label_dim[1], label_dim));
       PADDLE_ENFORCE_EQ(
           inference_dim[0], label_dim[0],
-          "ShapeError: the output's num_rows of AccuracyOp must be"
-          " the same as label's num_rows. But received output's "
-          "shape = [%s], label's shape = [%s], output's num_rows = %d, label's "
-          "num_rows = %d",
-          inference_dim, label_dim, inference_dim[0], label_dim[0]);
+          platform::errors::InvalidArgument(
+              "ShapeError: the output's num_rows of AccuracyOp must be"
+              " the same as label's num_rows. But received output's "
+              "shape = [%s], label's shape = [%s], output's num_rows = %d, "
+              "label's num_rows = %d",
+              inference_dim, label_dim, inference_dim[0], label_dim[0]));
     }
 
     ctx->SetOutputDim("Accuracy", {1});
......
@@ -56,8 +56,6 @@ template <typename T>
 class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use CUDAPlace.");
     auto* inference = ctx.Input<Tensor>("Out");
     auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
......
@@ -23,8 +23,8 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SmoothL1Loss");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "SmoothL1Loss");
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
@@ -34,14 +34,20 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
       check = false;
     }
     if (check) {
-      PADDLE_ENFORCE_EQ(x_dims, y_dims);
+      PADDLE_ENFORCE_EQ(x_dims, y_dims,
+                        platform::errors::InvalidArgument(
+                            "Input(X) and Input(Y) of SmoothL1LossOp should "
+                            "have the same size"));
     }
     PADDLE_ENFORCE_GE(x_dims.size(), 2,
-                      "The tensor rank of Input(X) should not be less than 2.");
+                      platform::errors::InvalidArgument(
+                          "The tensor rank of Input(X) of SmoothL1LossOp "
+                          "should not be less than 2."));
     if (ctx->HasInput("InsideWeight")) {
-      PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"),
-                     "If weights are provided, must specify both "
-                     "inside and outside weights.");
+      PADDLE_ENFORCE_EQ(ctx->HasInput("OutsideWeight"), true,
+                        platform::errors::InvalidArgument(
+                            "If weights are provided, must specify both "
+                            "inside and outside weights."));
       auto dims = ctx->GetInputDim("InsideWeight");
       bool check = true;
       if ((!ctx->IsRuntime()) &&
@@ -49,7 +55,10 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
         check = false;
       }
       if (check) {
-        PADDLE_ENFORCE_EQ(dims, x_dims);
+        PADDLE_ENFORCE_EQ(x_dims, dims,
+                          platform::errors::InvalidArgument(
+                              "Input(X) and Input(InsideWeight) of "
+                              "SmoothL1LossOp should have the same size"));
       }
 
       dims = ctx->GetInputDim("OutsideWeight");
@@ -59,7 +68,10 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
         check = false;
       }
       if (check) {
-        PADDLE_ENFORCE_EQ(dims, x_dims);
+        PADDLE_ENFORCE_EQ(x_dims, dims,
+                          platform::errors::InvalidArgument(
+                              "Input(X) and Input(OutsideWeight) of "
+                              "SmoothL1LossOp should have the same size"));
       }
     }
@@ -134,13 +146,17 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     PADDLE_ENFORCE_GE(out_dims.size(), 2,
-                      "The tensor rank of Input(Out@Grad) should be 2.");
+                      platform::errors::InvalidArgument(
+                          "The tensor rank of Input(Out@Grad) should be 2."));
     if (ctx->IsRuntime()) {
       PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
-                        "The 1st dimension of Input(Out@Grad) must be "
-                        "same as input.");
+                        platform::errors::InvalidArgument(
+                            "The 1st dimension of Input(Out@Grad) must be "
+                            "same as input in SmoothL1LossGradOp."));
       PADDLE_ENFORCE_EQ(out_dims[1], 1,
-                        "The 2nd dimension of Input(Out@Grad) must be 1.");
+                        platform::errors::InvalidArgument(
+                            "The 2nd dimension of Input(Out@Grad) must be 1 in "
+                            "SmoothL1LossGradOp."));
     }
 
     auto x_grad_name = framework::GradVarName("X");
......
@@ -74,12 +74,13 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
         static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
     auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
     if (diag_num > 0) {
-      PADDLE_ENFORCE_GT(size, (diag_num - 1) * (diag_step + 1),
-                        "ShapeError: the diagonal's elements is equal (num-1) "
-                        "* (step-1) with num %d, step %d,"
-                        "It should be smaller than %d, but received %d",
-                        diag_num, diag_step, (diag_num - 1) * (diag_step + 1),
-                        size);
+      PADDLE_ENFORCE_GT(
+          size, (diag_num - 1) * (diag_step + 1),
+          platform::errors::InvalidArgument(
+              "ShapeInvalid: the diagonal's elements is equal (num-1) "
+              "* (step-1) with num %d, step %d,"
+              "It should be smaller than %d, but received %d",
+              diag_num, diag_step, (diag_num - 1) * (diag_step + 1), size));
       for (int64_t i = 0; i < diag_num; ++i) {
         int64_t pos = i * diag_step + i;
         data[pos] = diag_val;
@@ -93,25 +94,27 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      "Output(Out) of UniformRandomOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "UniformRandom");
+
     PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
                       ctx->Attrs().Get<float>("max"),
-                      "uniform_random's min must less then max");
+                      platform::errors::InvalidArgument(
+                          "uniform_random's min must less then max"));
     PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
-                      "diag_num must greater than or equal 0");
+                      platform::errors::InvalidArgument(
+                          "diag_num must greater than or equal 0"));
     PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
-                      "diag_step must greater than or equal 0");
+                      platform::errors::InvalidArgument(
+                          "diag_step must greater than or equal 0"));
 
     if (ctx->HasInputs("ShapeTensorList")) {
       // top prority shape
       auto inputs_name = ctx->Inputs("ShapeTensorList");
       PADDLE_ENFORCE_GT(
           inputs_name.size(), 0,
-          "Input(ShapeTensorList)'size of Op(uniform_random) can't be zero."
-          "Please check the Attr(shape)'s size of"
-          "Op(fluid.layers.uniform_random).)");
+          platform::errors::InvalidArgument(
+              "Input(ShapeTensorList)'size of Op(uniform_random) can't be zero."
+              "Please check the Attr(shape)'s size of"
+              "Op(fluid.layers.uniform_random).)"));
       auto out_dims = std::vector<int>(inputs_name.size(), -1);
       ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
@@ -122,10 +125,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
       auto shape_dims = ctx->GetInputDim("ShapeTensor");
       PADDLE_ENFORCE_EQ(
           shape_dims.size(), 1,
-          "ShapeError: Input(ShapeTensor)' dimension size of "
-          "Op(uniform_random) must be 1."
-          "But received ShapeTensor's dimensions = %d, shape = [%s]",
-          shape_dims.size(), shape_dims);
+          platform::errors::InvalidArgument(
+              "ShapeError: Input(ShapeTensor)' dimension size of "
+              "Op(uniform_random) must be 1."
+              "But received ShapeTensor's dimensions = %d, shape = [%s]",
+              shape_dims.size(), shape_dims));
       int num_ele = 1;
       for (int i = 0; i < shape_dims.size(); ++i) {
         num_ele *= shape_dims[i];
@@ -136,11 +140,12 @@ class UniformRandomOp : public framework::OperatorWithKernel {
       return;
     }
 
-    PADDLE_ENFORCE_EQ(
-        shape.empty(), false,
-        "if there is no Input(ShapeTensorList) and no Input(ShapeTensor),the "
-        "attr(shape) information must "
-        "be set by Attr(shape).");
+    PADDLE_ENFORCE_EQ(shape.empty(), false,
+                      platform::errors::InvalidArgument(
+                          "if there is no Input(ShapeTensorList) and no "
+                          "Input(ShapeTensor),the "
+                          "attr(shape) information must "
+                          "be set by Attr(shape)."));
     std::vector<int64_t> tensor_shape;
     tensor_shape.reserve(shape.size());
     for (auto dim : shape) {
......
@@ -21,6 +21,7 @@ import numpy as np
 from .wrapped_decorator import signature_safe_contextmanager
 from .core import VarDesc
 from . import unique_name
+from .data_feeder import check_variable_and_dtype, check_type, check_dtype
 
 __all__ = [
     'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
@@ -216,8 +217,10 @@ class UniformInitializer(Initializer):
         Returns:
             the initialization op
         """
-        assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
+        check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
+                                 "uniform_random")
+
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
             self._seed = block.program.random_seed
@@ -303,8 +306,10 @@ class NormalInitializer(Initializer):
         Returns:
             the initialization op
         """
-        assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
+        check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
+                                 "gaussian_random")
+
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
             self._seed = block.program.random_seed
@@ -494,8 +499,10 @@ class XavierInitializer(Initializer):
         Returns:
             the initialization op
         """
-        assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
+        check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
+                                 "xavier_init")
+
         f_in, f_out = self._compute_fans(var)
         # If fan_in and fan_out are passed, use them
......
@@ -5653,8 +5653,11 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
            #  [0.20541131]], dtype=float32)]
     """
+    check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'smooth_l1_loss')
+    check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'smooth_l1_loss')
+
     helper = LayerHelper('smooth_l1_loss', **locals())
     diff = helper.create_variable_for_type_inference(dtype=x.dtype)
     loss = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
@@ -8375,6 +8378,9 @@ def mean_iou(input, label, num_classes):
                                                           num_classes)
     """
     helper = LayerHelper('mean_iou', **locals())
+    check_variable_and_dtype(input, 'Predictions', ['int32', 'int64'],
+                             'mean_iou')
+    check_variable_and_dtype(label, 'Labels', ['int32', 'int64'], 'mean_iou')
     dtype = helper.input_dtype()
     out_mean_iou = helper.create_variable_for_type_inference(dtype='float32')
     out_wrong = helper.create_variable_for_type_inference(dtype='int32')
......
@@ -72,5 +72,42 @@ class TestGaussianRandomOp(unittest.TestCase):
         pass
 
 
+class TestGaussianRandomOpError(unittest.TestCase):
+    def setUp(self):
+        self.op_type = "gaussian_random"
+        self.inputs = {}
+        self.use_mkldnn = False
+        self.attrs = {
+            "shape": [1000, 784],
+            "mean": .0,
+            "std": 1.,
+            "seed": 10,
+            "use_mkldnn": self.use_mkldnn
+        }
+        self.outputs = ["Out"]
+
+    def test_errors(self):
+        program = fluid.Program()
+        with fluid.program_guard(fluid.Program(), program):
+            input_data = numpy.random.random((2, 4)).astype("float32")
+
+            block = program.global_block()
+            vout = block.create_var(name="Out", dtype='int32')
+            normal_initializer = fluid.initializer.NormalInitializer(
+                loc=0.0, scale=1.0, seed=0)
+
+            def test_Variable():
+                # the input type must be Variable
+                normal_initializer(input_data)
+
+            self.assertRaises(TypeError, test_Variable)
+
+            def test_type():
+                # dtype must be float32 or float64
+                normal_initializer(vout)
+
+            self.assertRaises(TypeError, test_type)
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -18,6 +18,7 @@ from __future__ import division
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid as fluid
 
 
 def compute_mean_iou(predictions, labels, num_classes, in_wrongs, in_corrects,
@@ -112,5 +113,20 @@ class TestCase1(TestMeanIOUOp):
         self.in_mean_iou_num = 2
 
 
+class TestMeanIOUOpError(unittest.TestCase):
+    def test_errors(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            # The input type of mean_iou must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            y1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, fluid.layers.mean_iou, x1, y1)
+            # The input dtype of mean_iou must be int32 or int64.
+            x2 = fluid.layers.data(name='x2', shape=[4], dtype="float32")
+            y2 = fluid.layers.data(name='y2', shape=[4], dtype="float32")
+            self.assertRaises(TypeError, fluid.layers.mean_iou, x2, y2)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -17,6 +17,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid as fluid
 
 
 def smooth_l1_loss_forward(val, sigma2):
@@ -105,5 +106,20 @@ class TestSmoothL1LossOp2(OpTest):
             no_grad_set=set(['Y', 'InsideWeight', 'OutsideWeight']))
 
 
+class TestSmoothL1LossOpError(unittest.TestCase):
+    def test_errors(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            # The input type of smooth_l1 must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            y1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, fluid.layers.smooth_l1, x1, y1)
+            # The input dtype of smooth_l1 must be float32 or float64.
+            x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
+            y2 = fluid.layers.data(name='y2', shape=[4], dtype="int32")
+            self.assertRaises(TypeError, fluid.layers.smooth_l1, x2, y2)
+
+
 if __name__ == '__main__':
     unittest.main()