Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9b503e4f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b503e4f
编写于
8月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4955 Fixes for Dynamic Augmentation Ops
Merge pull request !4955 from MahdiRahmaniHanzaki/dynamic-ops-fix
上级
528fb810
a5f9b8f9
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
285 addition
and
56 deletion
+285
-56
mindspore/ccsrc/minddata/dataset/api/transforms.cc
mindspore/ccsrc/minddata/dataset/api/transforms.cc
+2
-2
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
...e/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
+12
-5
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
...re/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
+12
-5
mindspore/dataset/transforms/vision/c_transforms.py
mindspore/dataset/transforms/vision/c_transforms.py
+2
-2
mindspore/dataset/transforms/vision/py_transforms.py
mindspore/dataset/transforms/vision/py_transforms.py
+4
-31
mindspore/dataset/transforms/vision/py_transforms_util.py
mindspore/dataset/transforms/vision/py_transforms_util.py
+28
-6
mindspore/dataset/transforms/vision/validators.py
mindspore/dataset/transforms/vision/validators.py
+2
-0
tests/ut/cpp/dataset/c_api_transforms_test.cc
tests/ut/cpp/dataset/c_api_transforms_test.cc
+52
-2
tests/ut/python/dataset/test_cutmix_batch_op.py
tests/ut/python/dataset/test_cutmix_batch_op.py
+69
-1
tests/ut/python/dataset/test_mixup_op.py
tests/ut/python/dataset/test_mixup_op.py
+82
-2
tests/ut/python/dataset/test_random_affine.py
tests/ut/python/dataset/test_random_affine.py
+20
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/transforms.cc
浏览文件 @
9b503e4f
...
...
@@ -382,7 +382,7 @@ CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format,
:
image_batch_format_
(
image_batch_format
),
alpha_
(
alpha
),
prob_
(
prob
)
{}
bool
CutMixBatchOperation
::
ValidateParams
()
{
if
(
alpha_
<
0
)
{
if
(
alpha_
<
=
0
)
{
MS_LOG
(
ERROR
)
<<
"CutMixBatch: alpha cannot be negative."
;
return
false
;
}
...
...
@@ -434,7 +434,7 @@ std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<H
MixUpBatchOperation
::
MixUpBatchOperation
(
float
alpha
)
:
alpha_
(
alpha
)
{}
bool
MixUpBatchOperation
::
ValidateParams
()
{
if
(
alpha_
<
0
)
{
if
(
alpha_
<
=
0
)
{
MS_LOG
(
ERROR
)
<<
"MixUpBatch: alpha must be a positive floating value however it is: "
<<
alpha_
;
return
false
;
}
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
浏览文件 @
9b503e4f
...
...
@@ -59,7 +59,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Check inputs
if
(
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
RETURN_STATUS_UNEXPECTED
(
"You must batch before calling CutMixBatch."
);
RETURN_STATUS_UNEXPECTED
(
"You must
make sure images are HWC or CHW and
batch before calling CutMixBatch."
);
}
if
(
label_shape
.
size
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Label's must be in one-hot format and in a batch"
);
...
...
@@ -139,10 +139,17 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Compute labels
for
(
int
j
=
0
;
j
<
label_shape
[
1
];
j
++
)
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
]
%
label_shape
[
0
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
if
(
input
.
at
(
1
)
->
type
().
IsSignedInt
())
{
int64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
]
%
label_shape
[
0
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
}
else
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
]
%
label_shape
[
0
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
}
}
}
}
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
浏览文件 @
9b503e4f
...
...
@@ -38,7 +38,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Check inputs
if
(
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
RETURN_STATUS_UNEXPECTED
(
"You must batch before calling MixUpBatch"
);
RETURN_STATUS_UNEXPECTED
(
"You must
make sure images are HWC or CHW and
batch before calling MixUpBatch"
);
}
if
(
label_shape
.
size
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Label's must be in one-hot format and in a batch"
);
...
...
@@ -68,10 +68,17 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
RETURN_IF_NOT_OK
(
TypeCast
(
std
::
move
(
input
.
at
(
1
)),
&
out_labels
,
DataType
(
"float32"
)));
for
(
int64_t
i
=
0
;
i
<
label_shape
[
0
];
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
label_shape
[
1
];
j
++
)
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
if
(
input
.
at
(
1
)
->
type
().
IsSignedInt
())
{
int64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
}
else
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
}
}
}
...
...
mindspore/dataset/transforms/vision/c_transforms.py
浏览文件 @
9b503e4f
...
...
@@ -231,7 +231,7 @@ class Normalize(cde.NormalizeOp):
class
RandomAffine
(
cde
.
RandomAffineOp
):
"""
Apply Random affine transformation to the input
PIL
image.
Apply Random affine transformation to the input image.
Args:
degrees (int or float or sequence): Range of the rotation degrees.
...
...
@@ -681,12 +681,12 @@ class CenterCrop(cde.CenterCropOp):
class
RandomColor
(
cde
.
RandomColorOp
):
"""
Adjust the color of the input image by a fixed or random degree.
This operation works only with 3-channel color images.
Args:
degrees (sequence): Range of random color adjustment degrees.
It should be in (min, max) format. If min=max, then it is a
single fixed magnitude operation (default=(0.1,1.9)).
Works with 3-channel color images.
"""
@
check_positive_degrees
...
...
mindspore/dataset/transforms/vision/py_transforms.py
浏览文件 @
9b503e4f
...
...
@@ -1169,39 +1169,12 @@ class RandomAffine:
Returns:
img (PIL Image), Randomly affine transformed image.
"""
# rotation
angle
=
random
.
uniform
(
self
.
degrees
[
0
],
self
.
degrees
[
1
])
# translation
if
self
.
translate
is
not
None
:
max_dx
=
self
.
translate
[
0
]
*
img
.
size
[
0
]
max_dy
=
self
.
translate
[
1
]
*
img
.
size
[
1
]
translations
=
(
np
.
round
(
random
.
uniform
(
-
max_dx
,
max_dx
)),
np
.
round
(
random
.
uniform
(
-
max_dy
,
max_dy
)))
else
:
translations
=
(
0
,
0
)
# scale
if
self
.
scale_ranges
is
not
None
:
scale
=
random
.
uniform
(
self
.
scale_ranges
[
0
],
self
.
scale_ranges
[
1
])
else
:
scale
=
1.0
# shear
if
self
.
shear
is
not
None
:
if
len
(
self
.
shear
)
==
2
:
shear
=
[
random
.
uniform
(
self
.
shear
[
0
],
self
.
shear
[
1
]),
0.
]
elif
len
(
self
.
shear
)
==
4
:
shear
=
[
random
.
uniform
(
self
.
shear
[
0
],
self
.
shear
[
1
]),
random
.
uniform
(
self
.
shear
[
2
],
self
.
shear
[
3
])]
else
:
shear
=
0.0
return
util
.
random_affine
(
img
,
angle
,
translations
,
s
cale
,
shear
,
self
.
degrees
,
self
.
translate
,
s
elf
.
scale_ranges
,
s
elf
.
s
hear
,
self
.
resample
,
self
.
fill_value
)
...
...
mindspore/dataset/transforms/vision/py_transforms_util.py
浏览文件 @
9b503e4f
...
...
@@ -1153,6 +1153,34 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
if
not
is_pil
(
img
):
raise
ValueError
(
"Input image should be a Pillow image."
)
# rotation
angle
=
random
.
uniform
(
angle
[
0
],
angle
[
1
])
# translation
if
translations
is
not
None
:
max_dx
=
translations
[
0
]
*
img
.
size
[
0
]
max_dy
=
translations
[
1
]
*
img
.
size
[
1
]
translations
=
(
np
.
round
(
random
.
uniform
(
-
max_dx
,
max_dx
)),
np
.
round
(
random
.
uniform
(
-
max_dy
,
max_dy
)))
else
:
translations
=
(
0
,
0
)
# scale
if
scale
is
not
None
:
scale
=
random
.
uniform
(
scale
[
0
],
scale
[
1
])
else
:
scale
=
1.0
# shear
if
shear
is
not
None
:
if
len
(
shear
)
==
2
:
shear
=
[
random
.
uniform
(
shear
[
0
],
shear
[
1
]),
0.
]
elif
len
(
shear
)
==
4
:
shear
=
[
random
.
uniform
(
shear
[
0
],
shear
[
1
]),
random
.
uniform
(
shear
[
2
],
shear
[
3
])]
else
:
shear
=
0.0
output_size
=
img
.
size
center
=
(
img
.
size
[
0
]
*
0.5
+
0.5
,
img
.
size
[
1
]
*
0.5
+
0.5
)
...
...
@@ -1416,7 +1444,6 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
def
random_color
(
img
,
degrees
):
"""
Adjust the color of the input PIL image by a random degree.
...
...
@@ -1437,7 +1464,6 @@ def random_color(img, degrees):
def
random_sharpness
(
img
,
degrees
):
"""
Adjust the sharpness of the input PIL image by a random degree.
...
...
@@ -1458,7 +1484,6 @@ def random_sharpness(img, degrees):
def
auto_contrast
(
img
,
cutoff
,
ignore
):
"""
Automatically maximize the contrast of the input PIL image.
...
...
@@ -1479,7 +1504,6 @@ def auto_contrast(img, cutoff, ignore):
def
invert_color
(
img
):
"""
Invert colors of input PIL image.
...
...
@@ -1498,7 +1522,6 @@ def invert_color(img):
def
equalize
(
img
):
"""
Equalize the histogram of input PIL image.
...
...
@@ -1517,7 +1540,6 @@ def equalize(img):
def
uniform_augment
(
img
,
transforms
,
num_ops
):
"""
Uniformly select and apply a number of transforms sequentially from
a list of transforms. Randomly assigns a probability to each transform for
...
...
mindspore/dataset/transforms/vision/validators.py
浏览文件 @
9b503e4f
...
...
@@ -45,6 +45,7 @@ def check_cut_mix_batch_c(method):
[
image_batch_format
,
alpha
,
prob
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
type_check
(
image_batch_format
,
(
ImageBatchFormat
,),
"image_batch_format"
)
check_pos_float32
(
alpha
)
check_positive
(
alpha
,
"alpha"
)
check_value
(
prob
,
[
0
,
1
],
"prob"
)
return
method
(
self
,
*
args
,
**
kwargs
)
...
...
@@ -68,6 +69,7 @@ def check_mix_up_batch_c(method):
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
[
alpha
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
check_positive
(
alpha
,
"alpha"
)
check_pos_float32
(
alpha
)
return
method
(
self
,
*
args
,
**
kwargs
)
...
...
tests/ut/cpp/dataset/c_api_transforms_test.cc
浏览文件 @
9b503e4f
...
...
@@ -191,11 +191,37 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
1
,
-
0.5
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
1
,
-
0.5
);
EXPECT_EQ
(
cutmix_batch_op
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestCutMixBatchFail3
)
{
// Must fail because alpha can't be zero
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
5
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
one_hot_op
=
vision
::
OneHot
(
10
);
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
0.0
,
0.5
);
EXPECT_EQ
(
cutmix_batch_op
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestCutOut
)
{
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
...
...
@@ -365,6 +391,30 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
EXPECT_EQ
(
mixup_batch_op
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestMixUpBatchFail2
)
{
// This should fail because alpha can't be zero
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
5
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
one_hot_op
=
vision
::
OneHot
(
10
);
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},
{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
mixup_batch_op
=
vision
::
MixUpBatch
(
0.0
);
EXPECT_EQ
(
mixup_batch_op
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestMixUpBatchSuccess1
)
{
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
...
...
@@ -384,7 +434,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
ds
=
ds
->
Map
({
one_hot_op
},
{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
mixup_batch_op
=
vision
::
MixUpBatch
(
0.5
);
std
::
shared_ptr
<
TensorOperation
>
mixup_batch_op
=
vision
::
MixUpBatch
(
2.0
);
EXPECT_NE
(
mixup_batch_op
,
nullptr
);
// Create a Map operation on ds
...
...
tests/ut/python/dataset/test_cutmix_batch_op.py
浏览文件 @
9b503e4f
...
...
@@ -26,6 +26,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
config_get_set_num_parallel_workers
DATA_DIR
=
"../data/dataset/testCifar10Data"
DATA_DIR2
=
"../data/dataset/testImageNetData2/train/"
GENERATE_GOLDEN
=
False
...
...
@@ -114,6 +115,53 @@ def test_cutmix_batch_success2(plot=False):
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def
test_cutmix_batch_success3
(
plot
=
False
):
"""
Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2
"""
logger
.
info
(
"test_cutmix_batch_success3"
)
ds_original
=
ds
.
ImageFolderDatasetV2
(
dataset_dir
=
DATA_DIR2
,
shuffle
=
False
)
decode_op
=
vision
.
Decode
()
ds_original
=
ds_original
.
map
(
input_columns
=
[
"image"
],
operations
=
[
decode_op
])
ds_original
=
ds_original
.
batch
(
4
,
pad_info
=
{},
drop_remainder
=
True
)
images_original
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
ds_original
):
if
idx
==
0
:
images_original
=
image
else
:
images_original
=
np
.
append
(
images_original
,
image
,
axis
=
0
)
# CutMix Images
data1
=
ds
.
ImageFolderDatasetV2
(
dataset_dir
=
DATA_DIR2
,
shuffle
=
False
)
decode_op
=
vision
.
Decode
()
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
[
decode_op
])
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
)
data1
=
data1
.
batch
(
4
,
pad_info
=
{},
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
images_cutmix
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
if
plot
:
visualize_list
(
images_original
,
images_cutmix
)
num_samples
=
images_original
.
shape
[
0
]
mse
=
np
.
zeros
(
num_samples
)
for
i
in
range
(
num_samples
):
mse
[
i
]
=
diff_mse
(
images_cutmix
[
i
],
images_original
[
i
])
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def
test_cutmix_batch_nhwc_md5
():
"""
Test CutMixBatch on a batch of HWC images with MD5:
...
...
@@ -185,7 +233,7 @@ def test_cutmix_batch_fail1():
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
error_message
=
"You must
batch before calling CutMixBatch
"
error_message
=
"You must
make sure images are HWC or CHW and batch
"
assert
error_message
in
str
(
error
.
value
)
...
...
@@ -322,9 +370,28 @@ def test_cutmix_batch_fail7():
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail8
():
"""
Test CutMixBatch Fail 8
We expect this to fail because alpha is zero
"""
logger
.
info
(
"test_cutmix_batch_fail8"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
with
pytest
.
raises
(
ValueError
)
as
error
:
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
,
0.0
)
error_message
=
"Input is not within the required interval"
assert
error_message
in
str
(
error
.
value
)
if
__name__
==
"__main__"
:
test_cutmix_batch_success1
(
plot
=
True
)
test_cutmix_batch_success2
(
plot
=
True
)
test_cutmix_batch_success3
(
plot
=
True
)
test_cutmix_batch_nchw_md5
()
test_cutmix_batch_nhwc_md5
()
test_cutmix_batch_fail1
()
...
...
@@ -334,3 +401,4 @@ if __name__ == "__main__":
test_cutmix_batch_fail5
()
test_cutmix_batch_fail6
()
test_cutmix_batch_fail7
()
test_cutmix_batch_fail8
()
tests/ut/python/dataset/test_mixup_op.py
浏览文件 @
9b503e4f
...
...
@@ -25,6 +25,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
config_get_set_num_parallel_workers
DATA_DIR
=
"../data/dataset/testCifar10Data"
DATA_DIR2
=
"../data/dataset/testImageNetData2/train/"
GENERATE_GOLDEN
=
False
...
...
@@ -71,11 +72,59 @@ def test_mixup_batch_success1(plot=False):
def
test_mixup_batch_success2
(
plot
=
False
):
"""
Test MixUpBatch op with specified alpha parameter on ImageFolderDatasetV2
"""
logger
.
info
(
"test_mixup_batch_success2"
)
# Original Images
ds_original
=
ds
.
ImageFolderDatasetV2
(
dataset_dir
=
DATA_DIR2
,
shuffle
=
False
)
decode_op
=
vision
.
Decode
()
ds_original
=
ds_original
.
map
(
input_columns
=
[
"image"
],
operations
=
[
decode_op
])
ds_original
=
ds_original
.
batch
(
4
,
pad_info
=
{},
drop_remainder
=
True
)
images_original
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
ds_original
):
if
idx
==
0
:
images_original
=
image
else
:
images_original
=
np
.
append
(
images_original
,
image
,
axis
=
0
)
# MixUp Images
data1
=
ds
.
ImageFolderDatasetV2
(
dataset_dir
=
DATA_DIR2
,
shuffle
=
False
)
decode_op
=
vision
.
Decode
()
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
[
decode_op
])
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
mixup_batch_op
=
vision
.
MixUpBatch
(
2.0
)
data1
=
data1
.
batch
(
4
,
pad_info
=
{},
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
mixup_batch_op
)
images_mixup
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_mixup
=
image
else
:
images_mixup
=
np
.
append
(
images_mixup
,
image
,
axis
=
0
)
if
plot
:
visualize_list
(
images_original
,
images_mixup
)
num_samples
=
images_original
.
shape
[
0
]
mse
=
np
.
zeros
(
num_samples
)
for
i
in
range
(
num_samples
):
mse
[
i
]
=
diff_mse
(
images_mixup
[
i
],
images_original
[
i
])
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def
test_mixup_batch_success3
(
plot
=
False
):
"""
Test MixUpBatch op without specified alpha parameter.
Alpha parameter will be selected by default in this case
"""
logger
.
info
(
"test_mixup_batch_success
2
"
)
logger
.
info
(
"test_mixup_batch_success
3
"
)
# Original Images
ds_original
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
...
...
@@ -169,7 +218,7 @@ def test_mixup_batch_fail1():
images_mixup
=
image
else
:
images_mixup
=
np
.
append
(
images_mixup
,
image
,
axis
=
0
)
error_message
=
"You must
batch before calling MixUp
"
error_message
=
"You must
make sure images are HWC or CHW and batch
"
assert
error_message
in
str
(
error
.
value
)
...
...
@@ -207,6 +256,7 @@ def test_mixup_batch_fail3():
Test MixUpBatch op
We expect this to fail because label column is not passed to mixup_batch
"""
logger
.
info
(
"test_mixup_batch_fail3"
)
# Original Images
ds_original
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
ds_original
=
ds_original
.
batch
(
5
,
drop_remainder
=
True
)
...
...
@@ -237,11 +287,41 @@ def test_mixup_batch_fail3():
error_message
=
"Both images and labels columns are required"
assert
error_message
in
str
(
error
.
value
)
def
test_mixup_batch_fail4
():
"""
Test MixUpBatch Fail 2
We expect this to fail because alpha is zero
"""
logger
.
info
(
"test_mixup_batch_fail4"
)
# Original Images
ds_original
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
ds_original
=
ds_original
.
batch
(
5
)
images_original
=
np
.
array
([])
for
idx
,
(
image
,
_
)
in
enumerate
(
ds_original
):
if
idx
==
0
:
images_original
=
image
else
:
images_original
=
np
.
append
(
images_original
,
image
,
axis
=
0
)
# MixUp Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
with
pytest
.
raises
(
ValueError
)
as
error
:
vision
.
MixUpBatch
(
0.0
)
error_message
=
"Input is not within the required interval"
assert
error_message
in
str
(
error
.
value
)
if
__name__
==
"__main__"
:
test_mixup_batch_success1
(
plot
=
True
)
test_mixup_batch_success2
(
plot
=
True
)
test_mixup_batch_success3
(
plot
=
True
)
test_mixup_batch_md5
()
test_mixup_batch_fail1
()
test_mixup_batch_fail2
()
test_mixup_batch_fail3
()
test_mixup_batch_fail4
()
tests/ut/python/dataset/test_random_affine.py
浏览文件 @
9b503e4f
...
...
@@ -27,6 +27,7 @@ GENERATE_GOLDEN = False
DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
MNIST_DATA_DIR
=
"../data/dataset/testMnistData"
def
test_random_affine_op
(
plot
=
False
):
...
...
@@ -155,6 +156,24 @@ def test_random_affine_c_md5():
ds
.
config
.
set_num_parallel_workers
((
original_num_parallel_workers
))
def
test_random_affine_py_exception_non_pil_images
():
"""
Test RandomAffine: input img is ndarray and not PIL, expected to raise TypeError
"""
logger
.
info
(
"test_random_affine_exception_negative_degrees"
)
dataset
=
ds
.
MnistDataset
(
MNIST_DATA_DIR
,
num_parallel_workers
=
3
)
try
:
transform
=
py_vision
.
ComposeOp
([
py_vision
.
ToTensor
(),
py_vision
.
RandomAffine
(
degrees
=
(
15
,
15
))])
dataset
=
dataset
.
map
(
input_columns
=
[
"image"
],
operations
=
transform
(),
num_parallel_workers
=
3
,
python_multiprocessing
=
True
)
for
_
in
dataset
.
create_dict_iterator
():
break
except
RuntimeError
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"Pillow image"
in
str
(
e
)
def
test_random_affine_exception_negative_degrees
():
"""
Test RandomAffine: input degrees in negative, expected to raise ValueError
...
...
@@ -289,6 +308,7 @@ if __name__ == "__main__":
test_random_affine_op_c
(
plot
=
True
)
test_random_affine_md5
()
test_random_affine_c_md5
()
test_random_affine_py_exception_non_pil_images
()
test_random_affine_exception_negative_degrees
()
test_random_affine_exception_translation_range
()
test_random_affine_exception_scale_value
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录