Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7ad7024c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ad7024c
编写于
8月 26, 2020
作者:
M
mahdi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed 2D one-hot label problems in CutMix and MixUp
上级
43a61e46
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
203 addition
and
36 deletion
+203
-36
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
...e/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
+33
-17
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
...re/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
+31
-16
tests/ut/python/dataset/test_cutmix_batch_op.py
tests/ut/python/dataset/test_cutmix_batch_op.py
+50
-2
tests/ut/python/dataset/test_mixup_op.py
tests/ut/python/dataset/test_mixup_op.py
+89
-1
未找到文件。
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
浏览文件 @
7ad7024c
...
...
@@ -59,10 +59,17 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Check inputs
if
(
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
RETURN_STATUS_UNEXPECTED
(
"You must make sure images are HWC or CHW and batched before calling CutMixBatch."
);
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: You must make sure images are HWC or CHW and batched before calling CutMixBatch."
);
}
if
(
label_shape
.
size
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Label's must be in one-hot format and in a batch."
);
if
(
!
input
.
at
(
1
)
->
type
().
IsInt
())
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Wrong labels type. The second column (labels) must only include int types."
);
}
if
(
label_shape
.
size
()
!=
2
&&
label_shape
.
size
()
!=
3
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Wrong labels shape. The second column (labels) must have a shape of NC or NLC where N is the batch "
"size, L is the number of labels in each row, "
"and C is the number of classes. labels must be in one-hot format and in a batch."
);
}
if
((
image_shape
[
1
]
!=
1
&&
image_shape
[
1
]
!=
3
)
&&
image_batch_format_
==
ImageBatchFormat
::
kNCHW
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Image doesn't match the given image format."
);
...
...
@@ -84,10 +91,12 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Tensor holding the output labels
std
::
shared_ptr
<
Tensor
>
out_labels
;
RETURN_IF_NOT_OK
(
T
ensor
::
CreateEmpty
(
TensorShape
(
label_shape
),
DataType
(
DataType
::
DE_FLOAT32
),
&
out_labels
));
RETURN_IF_NOT_OK
(
T
ypeCast
(
std
::
move
(
input
.
at
(
1
)),
&
out_labels
,
DataType
(
DataType
::
DE_FLOAT32
)
));
int64_t
row_labels
=
label_shape
.
size
()
==
3
?
label_shape
[
1
]
:
1
;
int64_t
num_classes
=
label_shape
.
size
()
==
3
?
label_shape
[
2
]
:
label_shape
[
1
];
// Compute labels and images
for
(
int
i
=
0
;
i
<
image_shape
[
0
];
i
++
)
{
for
(
int
64_t
i
=
0
;
i
<
image_shape
[
0
];
i
++
)
{
// Calculating lambda
// If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1)
// then x = x1 / (x1+x2) is a random variable from Beta(a1, a2)
...
...
@@ -138,22 +147,29 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
}
// Compute labels
for
(
int
j
=
0
;
j
<
label_shape
[
1
];
j
++
)
{
if
(
input
.
at
(
1
)
->
type
().
IsSignedInt
())
{
int64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
]
%
label_shape
[
0
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
}
else
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
]
%
label_shape
[
0
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
for
(
int64_t
j
=
0
;
j
<
row_labels
;
j
++
)
{
for
(
int64_t
k
=
0
;
k
<
num_classes
;
k
++
)
{
std
::
vector
<
int64_t
>
first_index
=
label_shape
.
size
()
==
3
?
std
::
vector
{
i
,
j
,
k
}
:
std
::
vector
{
i
,
k
};
std
::
vector
<
int64_t
>
second_index
=
label_shape
.
size
()
==
3
?
std
::
vector
{
rand_indx
[
i
],
j
,
k
}
:
std
::
vector
{
rand_indx
[
i
],
k
};
if
(
input
.
at
(
1
)
->
type
().
IsSignedInt
())
{
int64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
first_index
));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
second_index
));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
(
first_index
,
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
}
else
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
first_index
));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
second_index
));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
(
first_index
,
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
}
}
}
}
}
std
::
shared_ptr
<
Tensor
>
out_images
;
RETURN_IF_NOT_OK
(
TensorVectorToBatchTensor
(
images
,
&
out_images
));
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
浏览文件 @
7ad7024c
...
...
@@ -38,10 +38,17 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Check inputs
if
(
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
RETURN_STATUS_UNEXPECTED
(
"You must make sure images are HWC or CHW and batched before calling MixUpBatch."
);
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch:You must make sure images are HWC or CHW and batched before calling MixUpBatch."
);
}
if
(
label_shape
.
size
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Label's must be in one-hot format and in a batch."
);
if
(
!
input
.
at
(
1
)
->
type
().
IsInt
())
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Wrong labels type. The second column (labels) must only include int types."
);
}
if
(
label_shape
.
size
()
!=
2
&&
label_shape
.
size
()
!=
3
)
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Wrong labels shape. The second column (labels) must have a shape of NC or NLC where N is the batch "
"size, L is the number of labels in each row, "
"and C is the number of classes. labels must be in one-hot format and in a batch."
);
}
if
((
image_shape
[
1
]
!=
1
&&
image_shape
[
1
]
!=
3
)
&&
(
image_shape
[
3
]
!=
1
&&
image_shape
[
3
]
!=
3
))
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Images must be in the shape of HWC or CHW."
);
...
...
@@ -65,23 +72,31 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
// Compute labels
std
::
shared_ptr
<
Tensor
>
out_labels
;
RETURN_IF_NOT_OK
(
TypeCast
(
std
::
move
(
input
.
at
(
1
)),
&
out_labels
,
DataType
(
"float32"
)));
RETURN_IF_NOT_OK
(
TypeCast
(
std
::
move
(
input
.
at
(
1
)),
&
out_labels
,
DataType
(
DataType
::
DE_FLOAT32
)));
int64_t
row_labels
=
label_shape
.
size
()
==
3
?
label_shape
[
1
]
:
1
;
int64_t
num_classes
=
label_shape
.
size
()
==
3
?
label_shape
[
2
]
:
label_shape
[
1
];
for
(
int64_t
i
=
0
;
i
<
label_shape
[
0
];
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
label_shape
[
1
];
j
++
)
{
if
(
input
.
at
(
1
)
->
type
().
IsSignedInt
())
{
int64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
}
else
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
for
(
int64_t
j
=
0
;
j
<
row_labels
;
j
++
)
{
for
(
int64_t
k
=
0
;
k
<
num_classes
;
k
++
)
{
std
::
vector
<
int64_t
>
first_index
=
label_shape
.
size
()
==
3
?
std
::
vector
{
i
,
j
,
k
}
:
std
::
vector
{
i
,
k
};
std
::
vector
<
int64_t
>
second_index
=
label_shape
.
size
()
==
3
?
std
::
vector
{
rand_indx
[
i
],
j
,
k
}
:
std
::
vector
{
rand_indx
[
i
],
k
};
if
(
input
.
at
(
1
)
->
type
().
IsSignedInt
())
{
int64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
first_index
));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
second_index
));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
(
first_index
,
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
}
else
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
first_index
));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
second_index
));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
(
first_index
,
lam
*
first_value
+
(
1
-
lam
)
*
second_value
));
}
}
}
}
// Compute images
for
(
int64_t
i
=
0
;
i
<
images
.
size
();
i
++
)
{
TensorShape
remaining
({
-
1
});
...
...
tests/ut/python/dataset/test_cutmix_batch_op.py
浏览文件 @
7ad7024c
...
...
@@ -27,6 +27,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
DATA_DIR
=
"../data/dataset/testCifar10Data"
DATA_DIR2
=
"../data/dataset/testImageNetData2/train/"
DATA_DIR3
=
"../data/dataset/testCelebAData/"
GENERATE_GOLDEN
=
False
...
...
@@ -36,7 +37,6 @@ def test_cutmix_batch_success1(plot=False):
Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
"""
logger
.
info
(
"test_cutmix_batch_success1"
)
# Original Images
ds_original
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
ds_original
=
ds_original
.
batch
(
5
,
drop_remainder
=
True
)
...
...
@@ -164,6 +164,53 @@ def test_cutmix_batch_success3(plot=False):
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def test_cutmix_batch_success4(plot=False):
    """
    CutMixBatch smoke test for a dataset where OneHot returns a 2D vector.

    The CelebA "attr" column is one-hot encoded before batching, and the batched
    "image"/"attr" columns are then passed through CutMixBatch in NHWC layout.
    Nothing is asserted on the pixel values: the test succeeds if the pipeline
    runs, and it logs the mean MSE between the original and the cut-mixed images.

    Args:
        plot (bool): when True, hand both image sets to visualize_list.
    """
    logger.info("test_cutmix_batch_success4")

    # Reference pipeline: decode and batch only, no augmentation
    baseline = ds.CelebADataset(DATA_DIR3, shuffle=False)
    baseline = baseline.map(input_columns=["image"], operations=[vision.Decode()])
    baseline = baseline.batch(2, drop_remainder=True)

    images_original = None
    for batch_no, (batch_images, _) in enumerate(baseline):
        # The first batch seeds the array; later batches are stacked along axis 0
        images_original = (batch_images if batch_no == 0
                           else np.append(images_original, batch_images, axis=0))

    # CutMix pipeline: decode, one-hot the attributes, batch, then CutMixBatch
    augmented = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)
    augmented = augmented.map(input_columns=["image"], operations=[vision.Decode()])
    augmented = augmented.map(input_columns=["attr"], operations=data_trans.OneHot(num_classes=100))
    cutmix = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
    augmented = augmented.batch(2, drop_remainder=True)
    augmented = augmented.map(input_columns=["image", "attr"], operations=cutmix)

    images_cutmix = None
    for batch_no, (batch_images, _) in enumerate(augmented):
        images_cutmix = (batch_images if batch_no == 0
                         else np.append(images_cutmix, batch_images, axis=0))

    if plot:
        visualize_list(images_original, images_cutmix)

    # Per-sample MSE between each augmented image and its original
    sample_count = images_original.shape[0]
    mse = np.zeros(sample_count)
    for pos in range(sample_count):
        mse[pos] = diff_mse(images_cutmix[pos], images_original[pos])
    logger.info("MSE= {}".format(str(np.mean(mse))))
def
test_cutmix_batch_nhwc_md5
():
"""
Test CutMixBatch on a batch of HWC images with MD5:
...
...
@@ -368,7 +415,7 @@ def test_cutmix_batch_fail7():
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
error_message
=
"CutMixBatch:
Label's must be in one-hot format and in a batch
"
error_message
=
"CutMixBatch:
Wrong labels shape. The second column (labels) must have a shape of NC or NLC
"
assert
error_message
in
str
(
error
.
value
)
...
...
@@ -394,6 +441,7 @@ if __name__ == "__main__":
test_cutmix_batch_success1
(
plot
=
True
)
test_cutmix_batch_success2
(
plot
=
True
)
test_cutmix_batch_success3
(
plot
=
True
)
test_cutmix_batch_success4
(
plot
=
True
)
test_cutmix_batch_nchw_md5
()
test_cutmix_batch_nhwc_md5
()
test_cutmix_batch_fail1
()
...
...
tests/ut/python/dataset/test_mixup_op.py
浏览文件 @
7ad7024c
...
...
@@ -26,6 +26,7 @@ from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_se
DATA_DIR
=
"../data/dataset/testCifar10Data"
DATA_DIR2
=
"../data/dataset/testImageNetData2/train/"
DATA_DIR3
=
"../data/dataset/testCelebAData/"
GENERATE_GOLDEN
=
False
...
...
@@ -162,6 +163,55 @@ def test_mixup_batch_success3(plot=False):
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def test_mixup_batch_success4(plot=False):
    """
    MixUpBatch smoke test for a dataset where OneHot returns a 2D vector.

    MixUpBatch is constructed without arguments, so its alpha parameter takes
    the default value. Nothing is asserted on the pixel values: the test
    succeeds if the pipeline runs, and it logs the mean MSE between the original
    and the mixed-up images.

    Args:
        plot (bool): when True, hand both image sets to visualize_list.
    """
    logger.info("test_mixup_batch_success4")

    # Original Images: decode and batch only, no augmentation
    baseline = ds.CelebADataset(DATA_DIR3, shuffle=False)
    baseline = baseline.map(input_columns=["image"], operations=[vision.Decode()])
    baseline = baseline.batch(2, drop_remainder=True)

    images_original = None
    for batch_no, (batch_images, _) in enumerate(baseline):
        # The first batch seeds the array; later batches are stacked along axis 0
        images_original = (batch_images if batch_no == 0
                           else np.append(images_original, batch_images, axis=0))

    # MixUp Images: decode, one-hot the attributes, batch, then MixUpBatch
    augmented = ds.CelebADataset(DATA_DIR3, shuffle=False)
    augmented = augmented.map(input_columns=["image"], operations=[vision.Decode()])
    augmented = augmented.map(input_columns=["attr"], operations=data_trans.OneHot(num_classes=100))
    mixup = vision.MixUpBatch()
    augmented = augmented.batch(2, drop_remainder=True)
    augmented = augmented.map(input_columns=["image", "attr"], operations=mixup)

    images_mixup = np.array([])
    for batch_no, (batch_images, _) in enumerate(augmented):
        images_mixup = (batch_images if batch_no == 0
                        else np.append(images_mixup, batch_images, axis=0))

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-sample MSE between each augmented image and its original
    sample_count = images_original.shape[0]
    mse = np.zeros(sample_count)
    for pos in range(sample_count):
        mse[pos] = diff_mse(images_mixup[pos], images_original[pos])
    logger.info("MSE= {}".format(str(np.mean(mse))))
def
test_mixup_batch_md5
():
"""
Test MixUpBatch with MD5:
...
...
@@ -218,7 +268,7 @@ def test_mixup_batch_fail1():
images_mixup
=
image
else
:
images_mixup
=
np
.
append
(
images_mixup
,
image
,
axis
=
0
)
error_message
=
"You must make sure images are HWC or CHW and batch"
error_message
=
"You must make sure images are HWC or CHW and batch
ed
"
assert
error_message
in
str
(
error
.
value
)
...
...
@@ -316,12 +366,50 @@ def test_mixup_batch_fail4():
assert
error_message
in
str
(
error
.
value
)
def test_mixup_batch_fail5():
    """
    Test MixUpBatch Fail 5
    We expect this to fail because labels are not OneHot encoded: the Cifar10
    "label" column is batched and passed to MixUpBatch without a OneHot step,
    so the op is expected to reject the label shape.

    Raises (expected and captured by the test):
        RuntimeError: carrying the "Wrong labels shape" message from MixUpBatch.
    """
    logger.info("test_mixup_batch_fail5")

    # MixUp Images
    # No reference (un-augmented) pipeline is built here: this test only checks
    # the error path, so the original images would never be looked at.
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)

    with pytest.raises(RuntimeError) as error:
        # The failure is expected while rows are pulled from the pipeline, so
        # iterating is enough; the rows themselves are not needed.
        for _ in data1:
            pass

    error_message = "MixUpBatch: Wrong labels shape. The second column (labels) must have a shape of NC or NLC"
    assert error_message in str(error.value)
if __name__ == "__main__":
    # Success cases are run with plotting enabled when the file is executed directly
    test_mixup_batch_success1(plot=True)
    test_mixup_batch_success2(plot=True)
    test_mixup_batch_success3(plot=True)
    test_mixup_batch_success4(plot=True)
    # MD5 test
    test_mixup_batch_md5()
    # Failure cases
    test_mixup_batch_fail1()
    test_mixup_batch_fail2()
    test_mixup_batch_fail3()
    test_mixup_batch_fail4()
    test_mixup_batch_fail5()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录