提交 56bd92b8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4473 Implementing Posterize Op

Merge pull request !4473 from islam_amin/posterize_op
......@@ -42,6 +42,7 @@
#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_posterize_op.h"
#include "minddata/dataset/kernels/image/random_resize_op.h"
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_rotation_op.h"
......@@ -142,6 +143,13 @@ PYBIND_REGISTER(RandomAffineOp, 1, ([](const py::module *m) {
py::arg("fill_value") = RandomAffineOp::kFillValue);
}));
PYBIND_REGISTER(RandomPosterizeOp, 1, ([](const py::module *m) {
(void)py::class_<RandomPosterizeOp, TensorOp, std::shared_ptr<RandomPosterizeOp>>(
*m, "RandomPosterizeOp", "Tensor operation to apply random posterize operation on an image.")
.def(py::init<uint8_t, uint8_t>(), py::arg("min_bit") = RandomPosterizeOp::kMinBit,
py::arg("max_bit") = RandomPosterizeOp::kMaxBit);
}));
PYBIND_REGISTER(
RandomResizeWithBBoxOp, 1, ([](const py::module *m) {
(void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
......
......@@ -32,6 +32,7 @@
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
#include "minddata/dataset/kernels/image/random_posterize_op.h"
#include "minddata/dataset/kernels/image/random_rotation_op.h"
#include "minddata/dataset/kernels/image/random_sharpness_op.h"
#include "minddata/dataset/kernels/image/random_solarize_op.h"
......@@ -217,6 +218,16 @@ std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob)
return op;
}
// Function to create RandomPosterizeOperation.
std::shared_ptr<RandomPosterizeOperation> RandomPosterize(uint8_t min_bit, uint8_t max_bit) {
auto op = std::make_shared<RandomPosterizeOperation>(min_bit, max_bit);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create RandomRotationOperation.
std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
bool expand, std::vector<float> center,
......@@ -725,6 +736,31 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
return tensor_op;
}
// RandomPosterizeOperation
RandomPosterizeOperation::RandomPosterizeOperation(uint8_t min_bit, uint8_t max_bit)
: min_bit_(min_bit), max_bit_(max_bit) {}
bool RandomPosterizeOperation::ValidateParams() {
if (min_bit_ < 1 || min_bit_ > 8) {
MS_LOG(ERROR) << "RandomPosterize: min_bit value is out of range [1-8]: " << min_bit_;
return false;
}
if (max_bit_ < 1 || max_bit_ > 8) {
MS_LOG(ERROR) << "RandomPosterize: max_bit value is out of range [1-8]: " << max_bit_;
return false;
}
if (max_bit_ < min_bit_) {
MS_LOG(ERROR) << "RandomPosterize: max_bit value is less than min_bit: max =" << max_bit_ << ", min = " << min_bit_;
return false;
}
return true;
}
std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
std::shared_ptr<RandomPosterizeOp> tensor_op = std::make_shared<RandomPosterizeOp>(min_bit_, max_bit_);
return tensor_op;
}
// Function to create RandomRotationOperation.
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
bool expand, std::vector<float> center,
......
......@@ -62,6 +62,7 @@ class RandomColorOperation;
class RandomColorAdjustOperation;
class RandomCropOperation;
class RandomHorizontalFlipOperation;
class RandomPosterizeOperation;
class RandomRotationOperation;
class RandomSharpnessOperation;
class RandomSolarizeOperation;
......@@ -220,6 +221,13 @@ std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob = 0.5);
/// \brief Function to create a RandomPosterize TensorOperation.
/// \notes Tensor operation to perform random posterize.
/// \param[in] min_bit - uint8_t representing the minimum bit in range. (Default=8)
/// \param[in] max_bit - uint8_t representing the maximum bit in range. (Default=8)
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RandomPosterizeOperation> RandomPosterize(uint8_t min_bit = 8, uint8_t max_bit = 8);
/// \brief Function to create a RandomRotation TensorOp
/// \notes Rotates the image according to parameters
/// \param[in] degrees A float vector size 2, representing the starting and ending degree
......@@ -521,6 +529,21 @@ class RandomHorizontalFlipOperation : public TensorOperation {
float probability_;
};
class RandomPosterizeOperation : public TensorOperation {
public:
explicit RandomPosterizeOperation(uint8_t min_bit = 8, uint8_t max_bit = 8);
~RandomPosterizeOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
private:
uint8_t min_bit_;
uint8_t max_bit_;
};
class RandomRotationOperation : public TensorOperation {
public:
RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand,
......
......@@ -17,6 +17,7 @@ add_library(kernels-image OBJECT
mixup_batch_op.cc
normalize_op.cc
pad_op.cc
posterize_op.cc
random_affine_op.cc
random_color_adjust_op.cc
random_crop_decode_resize_op.cc
......@@ -27,6 +28,7 @@ add_library(kernels-image OBJECT
random_horizontal_flip_op.cc
random_horizontal_flip_with_bbox_op.cc
bounding_box_augment_op.cc
random_posterize_op.cc
random_resize_op.cc
random_rotation_op.cc
random_select_subpolicy_op.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/posterize_op.h"
#include <opencv2/imgcodecs.hpp>
namespace mindspore {
namespace dataset {
const uint8_t PosterizeOp::kBit = 8;
PosterizeOp::PosterizeOp(uint8_t bit) : bit_(bit) {}
Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
uint8_t mask_value = ~((uint8_t)(1 << (8 - bit_)) - 1);
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
}
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of <H,W,C> or <H,W>");
}
std::vector<uint8_t> lut_vector;
for (std::size_t i = 0; i < 256; i++) {
lut_vector.push_back(i & mask_value);
}
cv::Mat in_image = input_cv->mat();
cv::Mat output_img;
cv::LUT(in_image, lut_vector, output_img);
std::shared_ptr<CVTensor> result_tensor;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));
*output = std::static_pointer_cast<Tensor>(result_tensor);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class PosterizeOp : public TensorOp {
public:
/// Default values
static const uint8_t kBit;
/// \brief Constructor
/// \param[in] bit: bits to use
explicit PosterizeOp(uint8_t bit = kBit);
~PosterizeOp() override = default;
std::string Name() const override { return kPosterizeOp; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
/// Member variables
private:
std::string kPosterizeOp = "PosterizeOp";
protected:
uint8_t bit_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_posterize_op.h"
#include <random>
#include <opencv2/imgcodecs.hpp>
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
const uint8_t RandomPosterizeOp::kMinBit = 8;
const uint8_t RandomPosterizeOp::kMaxBit = 8;
RandomPosterizeOp::RandomPosterizeOp(uint8_t min_bit, uint8_t max_bit)
: PosterizeOp(min_bit), min_bit_(min_bit), max_bit_(max_bit) {
rnd_.seed(GetSeed());
}
Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
bit_ = (min_bit_ == max_bit_) ? min_bit_ : std::uniform_int_distribution<uint8_t>(min_bit_, max_bit_)(rnd_);
return PosterizeOp::Compute(input, output);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/kernels/image/posterize_op.h"
namespace mindspore {
namespace dataset {
class RandomPosterizeOp : public PosterizeOp {
public:
/// Default values
static const uint8_t kMinBit;
static const uint8_t kMaxBit;
/// \brief Constructor
/// \param[in] min_bit: Minimum bit in range
/// \param[in] max_bit: Maximum bit in range
explicit RandomPosterizeOp(uint8_t min_bit = kMinBit, uint8_t max_bit = kMaxBit);
~RandomPosterizeOp() override = default;
std::string Name() const override { return kRandomPosterizeOp; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
/// Member variables
private:
std::string kRandomPosterizeOp = "RandomPosterizeOp";
uint8_t min_bit_;
uint8_t max_bit_;
std::mt19937 rnd_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
......@@ -50,7 +50,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec
check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
check_cut_mix_batch_c
check_cut_mix_batch_c, check_posterize
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
......@@ -459,6 +459,26 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
super().__init__(prob)
class RandomPosterize(cde.RandomPosterizeOp):
"""
Reduce the number of bits for each color channel.
Args:
bits (sequence or int): Range of random posterize to compress image.
bits values should always be in range of [1,8], and include at
least one integer values in the given range. It should be in
(min, max) or integer format. If min=max, then it is a single fixed
magnitude operation (default=8).
"""
@check_posterize
def __init__(self, bits=(8, 8)):
self.bits = bits
if isinstance(bits, int):
bits = (bits, bits)
super().__init__(bits[0], bits[1])
class RandomVerticalFlip(cde.RandomVerticalFlipOp):
"""
Flip the input image vertically, randomly with a given probability.
......@@ -676,6 +696,7 @@ class RandomColor(cde.RandomColorOp):
def __init__(self, degrees=(0.1, 1.9)):
super().__init__(*degrees)
class RandomColorAdjust(cde.RandomColorAdjustOp):
"""
Randomly adjust the brightness, contrast, saturation, and hue of the input image.
......
......@@ -162,6 +162,28 @@ def check_crop(method):
return new_method
def check_posterize(method):
""""A wrapper that wraps a parameter checker to the original function(posterize operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[bits], _ = parse_user_args(method, *args, **kwargs)
if bits is not None:
type_check(bits, (list, tuple, int), "bits")
if isinstance(bits, int):
check_value(bits, [1, 8])
if isinstance(bits, (list, tuple)):
if len(bits) != 2:
raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
for item in bits:
check_uint8(item, "bits")
# also checks if min <= max
check_range(bits, [1, 8])
return method(self, *args, **kwargs)
return new_method
def check_resize_interpolation(method):
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
......
......@@ -789,6 +789,120 @@ TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomPosterizeFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomPosterize with invalid params.";
// Create objects for the tensor ops
// Invalid max > 8
std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize(1, 9);
EXPECT_EQ(posterize, nullptr);
// Invalid min < 1
posterize = vision::RandomPosterize(0, 8);
EXPECT_EQ(posterize, nullptr);
// min > max
posterize = vision::RandomPosterize(8, 1);
EXPECT_EQ(posterize, nullptr);
}
TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomPosterizeSuccess1 with non-default params.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> posterize =
vision::RandomPosterize(1, 4);
EXPECT_NE(posterize, nullptr);
// Create a Map operation on ds
ds = ds->Map({posterize});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomPosterizeSuccess2 with default params.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize();
EXPECT_NE(posterize, nullptr);
// Create a Map operation on ds
ds = ds->Map({posterize});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomSharpness) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSharpness.";
......
......@@ -154,6 +154,10 @@ void CVOpCommon::CheckImageShapeAndData(const std::shared_ptr<Tensor> &output_te
expect_image_path = dir_path + "imagefolder/apple_expect_random_sharpness.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_sharpness.jpg";
break;
case kRandomPosterize:
expect_image_path = dir_path + "imagefolder/apple_expect_random_posterize.jpg";
actual_image_path = dir_path + "imagefolder/apple_actual_random_posterize.jpg";
break;
default:
MS_LOG(INFO) << "Not pass verification! Operation type does not exists.";
EXPECT_EQ(0, 1);
......
......@@ -42,6 +42,7 @@ class CVOpCommon : public Common {
kRandomSharpness,
kInvert,
kRandomAffine,
kRandomPosterize,
kAutoContrast,
kEqualize
};
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/random_posterize_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestRandomPosterizeOp : public UT::CVOP::CVOpCommon {
public:
MindDataTestRandomPosterizeOp() : CVOpCommon() {}
};
TEST_F(MindDataTestRandomPosterizeOp, TestOp1) {
MS_LOG(INFO) << "Doing testRandomPosterize.";
std::shared_ptr<Tensor> output_tensor;
std::unique_ptr<RandomPosterizeOp> op(new RandomPosterizeOp(1, 1));
EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor);
EXPECT_TRUE(s.IsOk());
CheckImageShapeAndData(output_tensor, kRandomPosterize);
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomPosterize op in DE
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_list, save_and_check_md5, \
config_get_set_seed, config_get_set_num_parallel_workers
GENERATE_GOLDEN = False
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_random_posterize_op_c(plot=False, run_golden=True):
"""
Test RandomPosterize in C transformations
"""
logger.info("test_random_posterize_op_c")
original_seed = config_get_set_seed(55)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# define map operations
transforms1 = [
c_vision.Decode(),
c_vision.RandomPosterize((1, 8))
]
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(input_columns=["image"], operations=transforms1)
# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode()])
image_posterize = []
image_original = []
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
image1 = item1["image"]
image2 = item2["image"]
image_posterize.append(image1)
image_original.append(image2)
if run_golden:
# check results with md5 comparison
filename = "random_posterize_01_result_c.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
if plot:
visualize_list(image_original, image_posterize)
# Restore configuration
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_random_posterize_op_fixed_point_c(plot=False, run_golden=True):
"""
Test RandomPosterize in C transformations with fixed point
"""
logger.info("test_random_posterize_op_c")
# define map operations
transforms1 = [
c_vision.Decode(),
c_vision.RandomPosterize(1)
]
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(input_columns=["image"], operations=transforms1)
# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode()])
image_posterize = []
image_original = []
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
image1 = item1["image"]
image2 = item2["image"]
image_posterize.append(image1)
image_original.append(image2)
if run_golden:
# check results with md5 comparison
filename = "random_posterize_fixed_point_01_result_c.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
if plot:
visualize_list(image_original, image_posterize)
def test_random_posterize_exception_bit():
"""
Test RandomPosterize: out of range input bits and invalid type
"""
logger.info("test_random_posterize_exception_bit")
# Test max > 8
try:
_ = c_vision.RandomPosterize((1, 9))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required interval of (1 to 8)."
# Test min < 1
try:
_ = c_vision.RandomPosterize((0, 7))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required interval of (1 to 8)."
# Test max < min
try:
_ = c_vision.RandomPosterize((8, 1))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required interval of (1 to 8)."
# Test wrong type (not uint8)
try:
_ = c_vision.RandomPosterize(1.1)
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Argument bits with value 1.1 is not of type (<class 'list'>, <class 'tuple'>, <class 'int'>)."
# Test wrong number of bits
try:
_ = c_vision.RandomPosterize((1, 1, 1))
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."
if __name__ == "__main__":
test_random_posterize_op_c(plot=True)
test_random_posterize_op_fixed_point_c(plot=True)
test_random_posterize_exception_bit()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册