提交 e9e40b68 编写于 作者: T Tinazhang

Bug fix

上级 6cbde2b3
......@@ -55,11 +55,11 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
// apply C++ ops (note: python OPs are not accepted)
if (count == 1) {
(**tensor_op).Compute(input, output);
RETURN_IF_NOT_OK((**tensor_op).Compute(input, output));
} else if (count % 2 == 0) {
(**tensor_op).Compute(*output, even_out_ptr);
RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr));
} else {
(**tensor_op).Compute(even_out, output);
RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output));
}
count++;
}
......
......@@ -226,6 +226,27 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)
def test_cpp_uniform_augment_random_crop_ut():
batch_size=2
cifar10_dir = "../data/dataset/testCifar10Data"
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]
transforms_ua = [
C.RandomCrop(size=[224, 224]),
C.RandomHorizontalFlip()
]
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1)
ds1 = ds1.map(input_columns="image", operations=uni_aug)
# apply DatasetOps
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
num_batches = 0
try:
for data in ds1.create_dict_iterator():
num_batches += 1
except BaseException as e:
assert "Crop size" in str(e)
if __name__ == "__main__":
test_uniform_augment(num_ops=1)
......@@ -233,3 +254,4 @@ if __name__ == "__main__":
test_cpp_uniform_augment_exception_pyops(num_ops=1)
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
test_cpp_uniform_augment_random_crop_ut()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册