diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc index 5c9650688c6c17848adfd71595e3dbb7536a2f8c..0d4f392681f9e1efc6e3cd228ece7805d51e147a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/posterize_op.cc @@ -40,6 +40,8 @@ Status PosterizeOp::Compute(const std::shared_ptr &input, std::shared_pt } cv::Mat in_image = input_cv->mat(); cv::Mat output_img; + CHECK_FAIL_RETURN_UNEXPECTED(in_image.depth() == CV_8U || in_image.depth() == CV_8S, + "Input image data type can not be float, but got " + input->type().ToString()); cv::LUT(in_image, lut_vector, output_img); std::shared_ptr result_tensor; RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor)); diff --git a/tests/ut/python/dataset/test_random_posterize.py b/tests/ut/python/dataset/test_random_posterize.py index 7adf758af93fc80dd4559c515a16aa09efd73d56..5748a49e736524a847e98665d51c171abe17d8f3 100644 --- a/tests/ut/python/dataset/test_random_posterize.py +++ b/tests/ut/python/dataset/test_random_posterize.py @@ -142,8 +142,29 @@ def test_random_posterize_exception_bit(): logger.info("Got an exception in DE: {}".format(str(e))) assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2." +def test_rescale_with_random_posterize(): + """ + Test RandomPosterize: only support CV_8S/CV_8U + """ + logger.info("test_rescale_with_random_posterize") + + DATA_DIR_10 = "../data/dataset/testCifar10Data" + dataset = ds.Cifar10Dataset(DATA_DIR_10) + + rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0) + dataset = dataset.map(input_columns=["image"], operations=rescale_op) + + random_posterize_op = c_vision.RandomPosterize((4, 8)) + dataset = dataset.map(input_columns=["image"], operations=random_posterize_op, num_parallel_workers=1) + + try: + _ = dataset.output_shapes() + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Input image data type can not be float" in str(e) if __name__ == "__main__": skip_test_random_posterize_op_c(plot=True) skip_test_random_posterize_op_fixed_point_c(plot=True) test_random_posterize_exception_bit() + test_rescale_with_random_posterize()