magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 9b503e4f
Authored on Aug 22, 2020 by mindspore-ci-bot; committed via Gitee on Aug 22, 2020
!4955 Fixes for Dynamic Augmentation Ops
Merge pull request !4955 from MahdiRahmaniHanzaki/dynamic-ops-fix
Parents: 528fb810, a5f9b8f9
Showing 11 changed files with 285 additions and 56 deletions (+285, -56)
mindspore/ccsrc/minddata/dataset/api/transforms.cc                   +2    -2
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc    +12   -5
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc     +12   -5
mindspore/dataset/transforms/vision/c_transforms.py                  +2    -2
mindspore/dataset/transforms/vision/py_transforms.py                 +4    -31
mindspore/dataset/transforms/vision/py_transforms_util.py            +28   -6
mindspore/dataset/transforms/vision/validators.py                    +2    -0
tests/ut/cpp/dataset/c_api_transforms_test.cc                        +52   -2
tests/ut/python/dataset/test_cutmix_batch_op.py                      +69   -1
tests/ut/python/dataset/test_mixup_op.py                             +82   -2
tests/ut/python/dataset/test_random_affine.py                        +20   -0
mindspore/ccsrc/minddata/dataset/api/transforms.cc

@@ -382,7 +382,7 @@ CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format,
     : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
 
 bool CutMixBatchOperation::ValidateParams() {
-  if (alpha_ < 0) {
+  if (alpha_ <= 0) {
     MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
     return false;
   }
@@ -434,7 +434,7 @@ std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<H
 MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
 
 bool MixUpBatchOperation::ValidateParams() {
-  if (alpha_ < 0) {
+  if (alpha_ <= 0) {
     MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
     return false;
   }
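Both validators now reject alpha == 0 as well as negative values. The alpha value is used as the shape parameter of the Beta(alpha, alpha) distribution from which the mixing ratio is drawn, and that distribution is undefined at zero. A minimal NumPy sketch of the boundary case (an illustration, not code from this commit):

import numpy as np

alpha = 2.0
lam = np.random.beta(alpha, alpha)   # mixing ratio in (0, 1), as MixUpBatch / CutMixBatch draw it
print(lam)

try:
    np.random.beta(0.0, 0.0)         # alpha = 0 makes Beta(alpha, alpha) undefined
except ValueError as err:
    print("rejected:", err)          # mirrors why ValidateParams() now fails on alpha <= 0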
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc

@@ -59,7 +59,7 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
   // Check inputs
   if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
-    RETURN_STATUS_UNEXPECTED("You must batch before calling CutMixBatch.");
+    RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling CutMixBatch.");
   }
   if (label_shape.size() != 2) {
     RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch");
@@ -139,10 +139,17 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
     // Compute labels
     for (int j = 0; j < label_shape[1]; j++) {
-      uint64_t first_value, second_value;
-      RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
-      RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
-      RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
+      if (input.at(1)->type().IsSignedInt()) {
+        int64_t first_value, second_value;
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
+        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
+      } else {
+        uint64_t first_value, second_value;
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i] % label_shape[0], j}));
+        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, label_lam * first_value + (1 - label_lam) * second_value));
+      }
     }
   }
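The added branch only changes which integer type signed label values are read into (they were previously read through uint64_t); the mixing arithmetic itself is unchanged. A small NumPy sketch of that arithmetic, with label_lam standing in for the lambda adjusted by the cut area (illustration only):

import numpy as np

labels = np.eye(10, dtype=np.int64)[[3, 7]]   # batch of 2 one-hot labels
rand_indx = [1, 0]                            # partner sample chosen for each row
label_lam = 0.7                               # share kept from the original sample

mixed = np.zeros(labels.shape, dtype=np.float64)
for i in range(labels.shape[0]):
    for j in range(labels.shape[1]):
        first_value = labels[i, j]
        second_value = labels[rand_indx[i] % labels.shape[0], j]
        mixed[i, j] = label_lam * first_value + (1 - label_lam) * second_value

print(mixed)   # each row sums to 1, e.g. 0.7 on class 3 and 0.3 on class 7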
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc

@@ -38,7 +38,7 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
   // Check inputs
   if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
-    RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
+    RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch");
   }
   if (label_shape.size() != 2) {
     RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch");
@@ -68,10 +68,17 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
   RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType("float32")));
   for (int64_t i = 0; i < label_shape[0]; i++) {
     for (int64_t j = 0; j < label_shape[1]; j++) {
-      uint64_t first_value, second_value;
-      RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
-      RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
-      RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
+      if (input.at(1)->type().IsSignedInt()) {
+        int64_t first_value, second_value;
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
+        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
+      } else {
+        uint64_t first_value, second_value;
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
+        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
+        RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
+      }
     }
   }
mindspore/dataset/transforms/vision/c_transforms.py

@@ -231,7 +231,7 @@ class Normalize(cde.NormalizeOp):
 class RandomAffine(cde.RandomAffineOp):
     """
-    Apply Random affine transformation to the input PIL image.
+    Apply Random affine transformation to the input image.
 
     Args:
         degrees (int or float or sequence): Range of the rotation degrees.
@@ -681,12 +681,12 @@ class CenterCrop(cde.CenterCropOp):
 class RandomColor(cde.RandomColorOp):
     """
     Adjust the color of the input image by a fixed or random degree.
+    This operation works only with 3-channel color images.
 
     Args:
         degrees (sequence): Range of random color adjustment degrees.
             It should be in (min, max) format. If min=max, then it is a
             single fixed magnitude operation (default=(0.1,1.9)).
-            Works with 3-channel color images.
     """
 
     @check_positive_degrees
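The docstring fix reflects that the C++-backed RandomAffine runs on decoded image tensors inside the pipeline, not on PIL images. A hedged usage sketch; the dataset path is a placeholder and the keyword arguments follow the docstring's (min, max) range convention rather than a verified signature:

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision

# Placeholder path; Decode() turns the raw bytes into an image tensor for the C++ ops.
data = ds.ImageFolderDatasetV2(dataset_dir="path/to/images", shuffle=False)
data = data.map(input_columns=["image"], operations=[c_vision.Decode()])

# degrees may be a single number or a (min, max) range; translate/scale are optional ranges.
affine_op = c_vision.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))
data = data.map(input_columns=["image"], operations=[affine_op])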
mindspore/dataset/transforms/vision/py_transforms.py

@@ -1169,39 +1169,12 @@ class RandomAffine:
         Returns:
             img (PIL Image), Randomly affine transformed image.
         """
-        # rotation
-        angle = random.uniform(self.degrees[0], self.degrees[1])
-
-        # translation
-        if self.translate is not None:
-            max_dx = self.translate[0] * img.size[0]
-            max_dy = self.translate[1] * img.size[1]
-            translations = (np.round(random.uniform(-max_dx, max_dx)),
-                            np.round(random.uniform(-max_dy, max_dy)))
-        else:
-            translations = (0, 0)
-
-        # scale
-        if self.scale_ranges is not None:
-            scale = random.uniform(self.scale_ranges[0], self.scale_ranges[1])
-        else:
-            scale = 1.0
-
-        # shear
-        if self.shear is not None:
-            if len(self.shear) == 2:
-                shear = [random.uniform(self.shear[0], self.shear[1]), 0.]
-            elif len(self.shear) == 4:
-                shear = [random.uniform(self.shear[0], self.shear[1]),
-                         random.uniform(self.shear[2], self.shear[3])]
-        else:
-            shear = 0.0
-
         return util.random_affine(img,
-                                  angle,
-                                  translations,
-                                  scale,
-                                  shear,
+                                  self.degrees,
+                                  self.translate,
+                                  self.scale_ranges,
+                                  self.shear,
                                   self.resample,
                                   self.fill_value)
mindspore/dataset/transforms/vision/py_transforms_util.py

@@ -1153,6 +1153,34 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
     if not is_pil(img):
         raise ValueError("Input image should be a Pillow image.")
 
+    # rotation
+    angle = random.uniform(angle[0], angle[1])
+
+    # translation
+    if translations is not None:
+        max_dx = translations[0] * img.size[0]
+        max_dy = translations[1] * img.size[1]
+        translations = (np.round(random.uniform(-max_dx, max_dx)),
+                        np.round(random.uniform(-max_dy, max_dy)))
+    else:
+        translations = (0, 0)
+
+    # scale
+    if scale is not None:
+        scale = random.uniform(scale[0], scale[1])
+    else:
+        scale = 1.0
+
+    # shear
+    if shear is not None:
+        if len(shear) == 2:
+            shear = [random.uniform(shear[0], shear[1]), 0.]
+        elif len(shear) == 4:
+            shear = [random.uniform(shear[0], shear[1]),
+                     random.uniform(shear[2], shear[3])]
+    else:
+        shear = 0.0
+
     output_size = img.size
     center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
@@ -1416,7 +1444,6 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
 def random_color(img, degrees):
     """
     Adjust the color of the input PIL image by a random degree.
...
@@ -1437,7 +1464,6 @@ def random_color(img, degrees):
 def random_sharpness(img, degrees):
     """
     Adjust the sharpness of the input PIL image by a random degree.
...
@@ -1458,7 +1484,6 @@ def random_sharpness(img, degrees):
 def auto_contrast(img, cutoff, ignore):
     """
     Automatically maximize the contrast of the input PIL image.
...
@@ -1479,7 +1504,6 @@ def auto_contrast(img, cutoff, ignore):
 def invert_color(img):
     """
     Invert colors of input PIL image.
...
@@ -1498,7 +1522,6 @@ def invert_color(img):
 def equalize(img):
     """
     Equalize the histogram of input PIL image.
...
@@ -1517,7 +1540,6 @@ def equalize(img):
 def uniform_augment(img, transforms, num_ops):
     """
     Uniformly select and apply a number of transforms sequentially from
     a list of transforms. Randomly assigns a probability to each transform for
...
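Taken together with the py_transforms.py change above, the random sampling of angle, translation, scale and shear now happens inside util.random_affine, which receives the raw ranges instead of pre-drawn values. A standalone sketch of that sampling logic, using plain random/NumPy rather than the real helper (illustration only):

import random
import numpy as np

def sample_affine_params(degrees, translate, scale_ranges, shear, img_size):
    """Draw one set of affine parameters from the configured ranges (illustrative)."""
    angle = random.uniform(degrees[0], degrees[1])

    if translate is not None:
        max_dx = translate[0] * img_size[0]
        max_dy = translate[1] * img_size[1]
        translations = (np.round(random.uniform(-max_dx, max_dx)),
                        np.round(random.uniform(-max_dy, max_dy)))
    else:
        translations = (0, 0)

    scale = random.uniform(scale_ranges[0], scale_ranges[1]) if scale_ranges is not None else 1.0

    if shear is not None:
        if len(shear) == 2:
            shear = [random.uniform(shear[0], shear[1]), 0.]
        else:
            shear = [random.uniform(shear[0], shear[1]), random.uniform(shear[2], shear[3])]
    else:
        shear = 0.0
    return angle, translations, scale, shear

print(sample_affine_params((-15, 15), (0.1, 0.1), (0.9, 1.1), None, (224, 224)))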
mindspore/dataset/transforms/vision/validators.py

@@ -45,6 +45,7 @@ def check_cut_mix_batch_c(method):
         [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
         type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
         check_pos_float32(alpha)
+        check_positive(alpha, "alpha")
         check_value(prob, [0, 1], "prob")
 
         return method(self, *args, **kwargs)
@@ -68,6 +69,7 @@ def check_mix_up_batch_c(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
         [alpha], _ = parse_user_args(method, *args, **kwargs)
+        check_positive(alpha, "alpha")
         check_pos_float32(alpha)
 
         return method(self, *args, **kwargs)
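check_pos_float32 bounds alpha to the float32 range, while the added check_positive is what rejects alpha == 0; the new Python unit tests further below assert the resulting ValueError. A condensed sketch of the user-visible behaviour; the import aliases mirror the ones the tests use and should be treated as assumptions:

import pytest
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.vision.utils as mode   # assumed home of ImageBatchFormat

# alpha = 0.0 is now rejected when the op is constructed, before any data flows.
with pytest.raises(ValueError) as error:
    vision.MixUpBatch(0.0)
assert "Input is not within the required interval" in str(error.value)

with pytest.raises(ValueError) as error:
    vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
assert "Input is not within the required interval" in str(error.value)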
tests/ut/cpp/dataset/c_api_transforms_test.cc

@@ -191,11 +191,37 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
   ds = ds->Map({one_hot_op},{"label"});
   EXPECT_NE(ds, nullptr);
 
   std::shared_ptr<TensorOperation> cutmix_batch_op =
     vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
   EXPECT_EQ(cutmix_batch_op, nullptr);
 }
 
+TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
+  // Must fail because alpha can't be zero
+  // Create a Cifar10 Dataset
+  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  EXPECT_NE(ds, nullptr);
+
+  // Create a Batch operation on ds
+  int32_t batch_size = 5;
+  ds = ds->Batch(batch_size);
+  EXPECT_NE(ds, nullptr);
+
+  // Create objects for the tensor ops
+  std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
+  EXPECT_NE(one_hot_op, nullptr);
+
+  // Create a Map operation on ds
+  ds = ds->Map({one_hot_op},{"label"});
+  EXPECT_NE(ds, nullptr);
+
+  std::shared_ptr<TensorOperation> cutmix_batch_op =
+    vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 0.0, 0.5);
+  EXPECT_EQ(cutmix_batch_op, nullptr);
+}
+
 TEST_F(MindDataTestPipeline, TestCutOut) {
   // Create an ImageFolder Dataset
   std::string folder_path = datasets_root_path_ + "/testPK/data/";
@@ -365,6 +391,30 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
   EXPECT_EQ(mixup_batch_op, nullptr);
 }
 
+TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
+  // This should fail because alpha can't be zero
+  // Create a Cifar10 Dataset
+  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
+  EXPECT_NE(ds, nullptr);
+
+  // Create a Batch operation on ds
+  int32_t batch_size = 5;
+  ds = ds->Batch(batch_size);
+  EXPECT_NE(ds, nullptr);
+
+  // Create objects for the tensor ops
+  std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
+  EXPECT_NE(one_hot_op, nullptr);
+
+  // Create a Map operation on ds
+  ds = ds->Map({one_hot_op}, {"label"});
+  EXPECT_NE(ds, nullptr);
+
+  std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.0);
+  EXPECT_EQ(mixup_batch_op, nullptr);
+}
+
 TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
   // Create a Cifar10 Dataset
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
@@ -384,7 +434,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
   ds = ds->Map({one_hot_op}, {"label"});
   EXPECT_NE(ds, nullptr);
 
-  std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
+  std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(2.0);
   EXPECT_NE(mixup_batch_op, nullptr);
 
   // Create a Map operation on ds
tests/ut/python/dataset/test_cutmix_batch_op.py

@@ -26,6 +26,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
     config_get_set_num_parallel_workers
 
 DATA_DIR = "../data/dataset/testCifar10Data"
+DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
 
 GENERATE_GOLDEN = False
@@ -114,6 +115,53 @@ def test_cutmix_batch_success2(plot=False):
     logger.info("MSE= {}".format(str(np.mean(mse))))
 
 
+def test_cutmix_batch_success3(plot=False):
+    """
+    Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2
+    """
+    logger.info("test_cutmix_batch_success3")
+
+    ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
+    ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
+
+    images_original = None
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original, image, axis=0)
+
+    # CutMix Images
+    data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=[decode_op])
+
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+
+    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
+    data1 = data1.batch(4, pad_info={}, drop_remainder=True)
+    data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
+
+    images_cutmix = None
+    for idx, (image, _) in enumerate(data1):
+        if idx == 0:
+            images_cutmix = image
+        else:
+            images_cutmix = np.append(images_cutmix, image, axis=0)
+    if plot:
+        visualize_list(images_original, images_cutmix)
+
+    num_samples = images_original.shape[0]
+    mse = np.zeros(num_samples)
+    for i in range(num_samples):
+        mse[i] = diff_mse(images_cutmix[i], images_original[i])
+    logger.info("MSE= {}".format(str(np.mean(mse))))
+
+
 def test_cutmix_batch_nhwc_md5():
     """
     Test CutMixBatch on a batch of HWC images with MD5:
@@ -185,7 +233,7 @@ def test_cutmix_batch_fail1():
            images_cutmix = image
         else:
            images_cutmix = np.append(images_cutmix, image, axis=0)
-    error_message = "You must batch before calling CutMixBatch"
+    error_message = "You must make sure images are HWC or CHW and batch"
     assert error_message in str(error.value)
@@ -322,9 +370,28 @@ def test_cutmix_batch_fail7():
     assert error_message in str(error.value)
 
 
+def test_cutmix_batch_fail8():
+    """
+    Test CutMixBatch Fail 8
+    We expect this to fail because alpha is zero
+    """
+    logger.info("test_cutmix_batch_fail8")
+
+    # CutMixBatch Images
+    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+    with pytest.raises(ValueError) as error:
+        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
+    error_message = "Input is not within the required interval"
+    assert error_message in str(error.value)
+
+
 if __name__ == "__main__":
     test_cutmix_batch_success1(plot=True)
     test_cutmix_batch_success2(plot=True)
+    test_cutmix_batch_success3(plot=True)
     test_cutmix_batch_nchw_md5()
     test_cutmix_batch_nhwc_md5()
     test_cutmix_batch_fail1()
@@ -334,3 +401,4 @@ if __name__ == "__main__":
     test_cutmix_batch_fail5()
     test_cutmix_batch_fail6()
     test_cutmix_batch_fail7()
+    test_cutmix_batch_fail8()
tests/ut/python/dataset/test_mixup_op.py

@@ -25,6 +25,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
     config_get_set_num_parallel_workers
 
 DATA_DIR = "../data/dataset/testCifar10Data"
+DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
 
 GENERATE_GOLDEN = False
@@ -71,11 +72,59 @@ def test_mixup_batch_success1(plot=False):
 
 def test_mixup_batch_success2(plot=False):
+    """
+    Test MixUpBatch op with specified alpha parameter on ImageFolderDatasetV2
+    """
+    logger.info("test_mixup_batch_success2")
+
+    # Original Images
+    ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
+    ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
+
+    images_original = None
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original, image, axis=0)
+
+    # MixUp Images
+    data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=[decode_op])
+
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+
+    mixup_batch_op = vision.MixUpBatch(2.0)
+    data1 = data1.batch(4, pad_info={}, drop_remainder=True)
+    data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
+
+    images_mixup = None
+    for idx, (image, _) in enumerate(data1):
+        if idx == 0:
+            images_mixup = image
+        else:
+            images_mixup = np.append(images_mixup, image, axis=0)
+    if plot:
+        visualize_list(images_original, images_mixup)
+
+    num_samples = images_original.shape[0]
+    mse = np.zeros(num_samples)
+    for i in range(num_samples):
+        mse[i] = diff_mse(images_mixup[i], images_original[i])
+    logger.info("MSE= {}".format(str(np.mean(mse))))
+
+
+def test_mixup_batch_success3(plot=False):
     """
     Test MixUpBatch op without specified alpha parameter.
     Alpha parameter will be selected by default in this case
     """
-    logger.info("test_mixup_batch_success2")
+    logger.info("test_mixup_batch_success3")
 
     # Original Images
     ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
@@ -169,7 +218,7 @@ def test_mixup_batch_fail1():
            images_mixup = image
         else:
            images_mixup = np.append(images_mixup, image, axis=0)
-    error_message = "You must batch before calling MixUp"
+    error_message = "You must make sure images are HWC or CHW and batch"
     assert error_message in str(error.value)
@@ -207,6 +256,7 @@ def test_mixup_batch_fail3():
     Test MixUpBatch op
     We expect this to fail because label column is not passed to mixup_batch
     """
+    logger.info("test_mixup_batch_fail3")
 
     # Original Images
     ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
     ds_original = ds_original.batch(5, drop_remainder=True)
@@ -237,11 +287,41 @@ def test_mixup_batch_fail3():
     error_message = "Both images and labels columns are required"
     assert error_message in str(error.value)
 
 
+def test_mixup_batch_fail4():
+    """
+    Test MixUpBatch Fail 2
+    We expect this to fail because alpha is zero
+    """
+    logger.info("test_mixup_batch_fail4")
+
+    # Original Images
+    ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
+    ds_original = ds_original.batch(5)
+
+    images_original = np.array([])
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original, image, axis=0)
+
+    # MixUp Images
+    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
+
+    one_hot_op = data_trans.OneHot(num_classes=10)
+    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
+    with pytest.raises(ValueError) as error:
+        vision.MixUpBatch(0.0)
+    error_message = "Input is not within the required interval"
+    assert error_message in str(error.value)
+
+
 if __name__ == "__main__":
     test_mixup_batch_success1(plot=True)
     test_mixup_batch_success2(plot=True)
+    test_mixup_batch_success3(plot=True)
     test_mixup_batch_md5()
     test_mixup_batch_fail1()
     test_mixup_batch_fail2()
     test_mixup_batch_fail3()
+    test_mixup_batch_fail4()
tests/ut/python/dataset/test_random_affine.py

@@ -27,6 +27,7 @@ GENERATE_GOLDEN = False
 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
+MNIST_DATA_DIR = "../data/dataset/testMnistData"
 
 
 def test_random_affine_op(plot=False):
@@ -155,6 +156,24 @@ def test_random_affine_c_md5():
     ds.config.set_num_parallel_workers((original_num_parallel_workers))
 
 
+def test_random_affine_py_exception_non_pil_images():
+    """
+    Test RandomAffine: input img is ndarray and not PIL, expected to raise TypeError
+    """
+    logger.info("test_random_affine_exception_negative_degrees")
+    dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3)
+    try:
+        transform = py_vision.ComposeOp([py_vision.ToTensor(),
+                                         py_vision.RandomAffine(degrees=(15, 15))])
+        dataset = dataset.map(input_columns=["image"], operations=transform(), num_parallel_workers=3,
+                              python_multiprocessing=True)
+        for _ in dataset.create_dict_iterator():
+            break
+    except RuntimeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Pillow image" in str(e)
+
+
 def test_random_affine_exception_negative_degrees():
     """
     Test RandomAffine: input degrees in negative, expected to raise ValueError
@@ -289,6 +308,7 @@ if __name__ == "__main__":
     test_random_affine_op_c(plot=True)
     test_random_affine_md5()
     test_random_affine_c_md5()
+    test_random_affine_py_exception_non_pil_images()
     test_random_affine_exception_negative_degrees()
     test_random_affine_exception_translation_range()
     test_random_affine_exception_scale_value()
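The new test deliberately puts ToTensor ahead of RandomAffine, so RandomAffine receives an ndarray and the is_pil check added to util.random_affine raises. A hedged sketch of the ordering that avoids the error, converting to PIL first and to a tensor last (paths and operators assumed to match the test environment):

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.py_transforms as py_vision

# Assumes the same MNIST test data used by the unit test; the path is illustrative.
dataset = ds.MnistDataset("../data/dataset/testMnistData", num_parallel_workers=3)

# ToPIL gives RandomAffine the PIL image it expects; ToTensor runs only after the PIL-based ops.
transform = py_vision.ComposeOp([py_vision.ToPIL(),
                                 py_vision.RandomAffine(degrees=(15, 15)),
                                 py_vision.ToTensor()])
dataset = dataset.map(input_columns=["image"], operations=transform())

for _ in dataset.create_dict_iterator():
    break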