Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
113ff6ca
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
113ff6ca
编写于
8月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4367 CutMixBatch Augmentation Op
Merge pull request !4367 from MahdiRahmaniHanzaki/cutmix
上级
c01f0b66
3ecc53fb
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
1102 addition
and
11 deletion
+1102
-11
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc
...data/dataset/api/python/bindings/dataset/core/bindings.cc
+7
-0
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc
...set/api/python/bindings/dataset/kernels/image/bindings.cc
+8
-0
mindspore/ccsrc/minddata/dataset/api/transforms.cc
mindspore/ccsrc/minddata/dataset/api/transforms.cc
+32
-0
mindspore/ccsrc/minddata/dataset/core/constants.h
mindspore/ccsrc/minddata/dataset/core/constants.h
+6
-0
mindspore/ccsrc/minddata/dataset/include/transforms.h
mindspore/ccsrc/minddata/dataset/include/transforms.h
+28
-0
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
+36
-3
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
+16
-3
mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
...spore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
+1
-0
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
...e/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
+166
-0
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.h
...re/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.h
+52
-0
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
...spore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
+56
-0
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
+13
-0
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
...re/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
+4
-2
mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
+1
-0
mindspore/dataset/transforms/vision/c_transforms.py
mindspore/dataset/transforms/vision/c_transforms.py
+32
-2
mindspore/dataset/transforms/vision/utils.py
mindspore/dataset/transforms/vision/utils.py
+6
-0
mindspore/dataset/transforms/vision/validators.py
mindspore/dataset/transforms/vision/validators.py
+15
-1
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/c_api_transforms_test.cc
tests/ut/cpp/dataset/c_api_transforms_test.cc
+171
-0
tests/ut/cpp/dataset/cutmix_batch_op_test.cc
tests/ut/cpp/dataset/cutmix_batch_op_test.cc
+115
-0
tests/ut/data/dataset/golden/cutmix_batch_c_nchw_result.npz
tests/ut/data/dataset/golden/cutmix_batch_c_nchw_result.npz
+0
-0
tests/ut/data/dataset/golden/cutmix_batch_c_nhwc_result.npz
tests/ut/data/dataset/golden/cutmix_batch_c_nhwc_result.npz
+0
-0
tests/ut/python/dataset/test_cutmix_batch_op.py
tests/ut/python/dataset/test_cutmix_batch_op.py
+336
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc
浏览文件 @
113ff6ca
...
...
@@ -110,5 +110,12 @@ PYBIND_REGISTER(InterpolationMode, 0, ([](const py::module *m) {
.
export_values
();
}));
PYBIND_REGISTER
(
ImageBatchFormat
,
0
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
ImageBatchFormat
>
(
*
m
,
"ImageBatchFormat"
,
py
::
arithmetic
())
.
value
(
"DE_IMAGE_BATCH_FORMAT_NHWC"
,
ImageBatchFormat
::
kNHWC
)
.
value
(
"DE_IMAGE_BATCH_FORMAT_NCHW"
,
ImageBatchFormat
::
kNCHW
)
.
export_values
();
}));
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc
浏览文件 @
113ff6ca
...
...
@@ -22,6 +22,7 @@
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
#include "minddata/dataset/kernels/image/center_crop_op.h"
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
#include "minddata/dataset/kernels/image/cut_out_op.h"
#include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/equalize_op.h"
...
...
@@ -105,6 +106,13 @@ PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) {
.
def
(
py
::
init
<
float
>
(),
py
::
arg
(
"alpha"
));
}));
PYBIND_REGISTER
(
CutMixBatchOp
,
1
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
CutMixBatchOp
,
TensorOp
,
std
::
shared_ptr
<
CutMixBatchOp
>>
(
*
m
,
"CutMixBatchOp"
,
"Tensor operation to cutmix a batch of images"
)
.
def
(
py
::
init
<
ImageBatchFormat
,
float
,
float
>
(),
py
::
arg
(
"image_batch_format"
),
py
::
arg
(
"alpha"
),
py
::
arg
(
"prob"
));
}));
PYBIND_REGISTER
(
ResizeOp
,
1
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
ResizeOp
,
TensorOp
,
std
::
shared_ptr
<
ResizeOp
>>
(
*
m
,
"ResizeOp"
,
"Tensor operation to resize an image. Takes height, width and mode"
)
...
...
mindspore/ccsrc/minddata/dataset/api/transforms.cc
浏览文件 @
113ff6ca
...
...
@@ -19,6 +19,7 @@
#include "minddata/dataset/kernels/image/center_crop_op.h"
#include "minddata/dataset/kernels/image/crop_op.h"
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
#include "minddata/dataset/kernels/image/cut_out_op.h"
#include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
...
...
@@ -70,6 +71,16 @@ std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vecto
return
op
;
}
// Function to create CutMixBatchOperation.
std
::
shared_ptr
<
CutMixBatchOperation
>
CutMixBatch
(
ImageBatchFormat
image_batch_format
,
float
alpha
,
float
prob
)
{
auto
op
=
std
::
make_shared
<
CutMixBatchOperation
>
(
image_batch_format
,
alpha
,
prob
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
}
return
op
;
}
// Function to create CutOutOp.
std
::
shared_ptr
<
CutOutOperation
>
CutOut
(
int32_t
length
,
int32_t
num_patches
)
{
auto
op
=
std
::
make_shared
<
CutOutOperation
>
(
length
,
num_patches
);
...
...
@@ -355,6 +366,27 @@ std::shared_ptr<TensorOp> CropOperation::Build() {
return
tensor_op
;
}
// CutMixBatchOperation
CutMixBatchOperation
::
CutMixBatchOperation
(
ImageBatchFormat
image_batch_format
,
float
alpha
,
float
prob
)
:
image_batch_format_
(
image_batch_format
),
alpha_
(
alpha
),
prob_
(
prob
)
{}
bool
CutMixBatchOperation
::
ValidateParams
()
{
if
(
alpha_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"CutMixBatch: alpha cannot be negative."
;
return
false
;
}
if
(
prob_
<
0
||
prob_
>
1
)
{
MS_LOG
(
ERROR
)
<<
"CutMixBatch: Probability has to be between 0 and 1."
;
return
false
;
}
return
true
;
}
std
::
shared_ptr
<
TensorOp
>
CutMixBatchOperation
::
Build
()
{
std
::
shared_ptr
<
CutMixBatchOp
>
tensor_op
=
std
::
make_shared
<
CutMixBatchOp
>
(
image_batch_format_
,
alpha_
,
prob_
);
return
tensor_op
;
}
// CutOutOperation
CutOutOperation
::
CutOutOperation
(
int32_t
length
,
int32_t
num_patches
)
:
length_
(
length
),
num_patches_
(
num_patches
)
{}
...
...
mindspore/ccsrc/minddata/dataset/core/constants.h
浏览文件 @
113ff6ca
...
...
@@ -41,6 +41,12 @@ enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
// Possible values for Border types
enum
class
BorderType
{
kConstant
=
0
,
kEdge
=
1
,
kReflect
=
2
,
kSymmetric
=
3
};
// Possible values for Image format types in a batch
enum
class
ImageBatchFormat
{
kNHWC
=
0
,
kNCHW
=
1
};
// Possible values for Image format types
enum
class
ImageFormat
{
HWC
=
0
,
CHW
=
1
,
HW
=
2
};
// Possible interpolation modes
enum
class
InterpolationMode
{
kLinear
=
0
,
kNearestNeighbour
=
1
,
kCubic
=
2
,
kArea
=
3
};
...
...
mindspore/ccsrc/minddata/dataset/include/transforms.h
浏览文件 @
113ff6ca
...
...
@@ -49,6 +49,7 @@ namespace vision {
// Transform Op classes (in alphabetical order)
class
CenterCropOperation
;
class
CropOperation
;
class
CutMixBatchOperation
;
class
CutOutOperation
;
class
DecodeOperation
;
class
HwcToChwOperation
;
...
...
@@ -86,6 +87,16 @@ std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size);
/// \return Shared pointer to the current TensorOp
std
::
shared_ptr
<
CropOperation
>
Crop
(
std
::
vector
<
int32_t
>
coordinates
,
std
::
vector
<
int32_t
>
size
);
/// \brief Function to apply CutMix on a batch of images
/// \notes Masks a random section of each image with the corresponding part of another randomly selected image in
/// that batch
/// \param[in] image_batch_format The format of the batch
/// \param[in] alpha The hyperparameter of beta distribution (default = 1.0)
/// \param[in] prob The probability by which CutMix is applied to each image (default = 1.0)
/// \return Shared pointer to the current TensorOp
std
::
shared_ptr
<
CutMixBatchOperation
>
CutMixBatch
(
ImageBatchFormat
image_batch_format
,
float
alpha
=
1.0
,
float
prob
=
1.0
);
/// \brief Function to create a CutOut TensorOp
/// \notes Randomly cut (mask) out a given number of square patches from the input image
/// \param[in] length Integer representing the side length of each square patch
...
...
@@ -305,6 +316,22 @@ class CropOperation : public TensorOperation {
std
::
vector
<
int32_t
>
size_
;
};
class
CutMixBatchOperation
:
public
TensorOperation
{
public:
explicit
CutMixBatchOperation
(
ImageBatchFormat
image_batch_format
,
float
alpha
=
1.0
,
float
prob
=
1.0
);
~
CutMixBatchOperation
()
=
default
;
std
::
shared_ptr
<
TensorOp
>
Build
()
override
;
bool
ValidateParams
()
override
;
private:
float
alpha_
;
float
prob_
;
ImageBatchFormat
image_batch_format_
;
};
class
CutOutOperation
:
public
TensorOperation
{
public:
explicit
CutOutOperation
(
int32_t
length
,
int32_t
num_patches
=
1
);
...
...
@@ -318,6 +345,7 @@ class CutOutOperation : public TensorOperation {
private:
int32_t
length_
;
int32_t
num_patches_
;
ImageBatchFormat
image_batch_format_
;
};
class
DecodeOperation
:
public
TensorOperation
{
...
...
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
浏览文件 @
113ff6ca
...
...
@@ -655,7 +655,7 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
TensorShape
remaining
({
-
1
});
std
::
vector
<
int64_t
>
index
(
tensor_shape
.
size
(),
0
);
if
(
tensor_shape
.
size
()
<=
1
)
{
RETURN_STATUS_UNEXPECTED
(
"Tensor must be at least 2-D in order to unpack"
);
RETURN_STATUS_UNEXPECTED
(
"Tensor must be at least 2-D in order to unpack
.
"
);
}
TensorShape
element_shape
(
std
::
vector
<
int64_t
>
(
tensor_shape
.
begin
()
+
1
,
tensor_shape
.
end
()));
...
...
@@ -664,15 +664,48 @@ Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
std
::
shared_ptr
<
Tensor
>
out
;
RETURN_IF_NOT_OK
(
input
->
StartAddrOfIndex
(
index
,
&
start_addr_of_index
,
&
remaining
));
RETURN_IF_NOT_OK
(
input
->
CreateFromMemory
(
element_shape
,
input
->
type
(),
start_addr_of_index
,
&
out
));
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
element_shape
,
input
->
type
(),
start_addr_of_index
,
&
out
));
std
::
shared_ptr
<
CVTensor
>
cv_out
=
CVTensor
::
AsCVTensor
(
std
::
move
(
out
));
if
(
!
cv_out
->
mat
().
data
)
{
RETURN_STATUS_UNEXPECTED
(
"Could not convert to CV Tensor"
);
RETURN_STATUS_UNEXPECTED
(
"Could not convert to CV Tensor
.
"
);
}
output
->
push_back
(
cv_out
);
}
return
Status
::
OK
();
}
Status
BatchTensorToTensorVector
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
{
std
::
vector
<
int64_t
>
tensor_shape
=
input
->
shape
().
AsVector
();
TensorShape
remaining
({
-
1
});
std
::
vector
<
int64_t
>
index
(
tensor_shape
.
size
(),
0
);
if
(
tensor_shape
.
size
()
<=
1
)
{
RETURN_STATUS_UNEXPECTED
(
"Tensor must be at least 2-D in order to unpack."
);
}
TensorShape
element_shape
(
std
::
vector
<
int64_t
>
(
tensor_shape
.
begin
()
+
1
,
tensor_shape
.
end
()));
for
(;
index
[
0
]
<
tensor_shape
[
0
];
index
[
0
]
++
)
{
uchar
*
start_addr_of_index
=
nullptr
;
std
::
shared_ptr
<
Tensor
>
out
;
RETURN_IF_NOT_OK
(
input
->
StartAddrOfIndex
(
index
,
&
start_addr_of_index
,
&
remaining
));
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
element_shape
,
input
->
type
(),
start_addr_of_index
,
&
out
));
output
->
push_back
(
out
);
}
return
Status
::
OK
();
}
Status
TensorVectorToBatchTensor
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
if
(
input
.
empty
())
{
RETURN_STATUS_UNEXPECTED
(
"TensorVectorToBatchTensor: Received an empty vector."
);
}
std
::
vector
<
int64_t
>
tensor_shape
=
input
.
front
()
->
shape
().
AsVector
();
tensor_shape
.
insert
(
tensor_shape
.
begin
(),
input
.
size
());
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
TensorShape
(
tensor_shape
),
input
.
at
(
0
)
->
type
(),
output
));
for
(
int
i
=
0
;
i
<
input
.
size
();
i
++
)
{
RETURN_IF_NOT_OK
((
*
output
)
->
InsertTensor
({
i
},
input
[
i
]));
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
浏览文件 @
113ff6ca
...
...
@@ -158,11 +158,24 @@ Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
std
::
shared_ptr
<
Tensor
>
append
);
/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors
///
@
param input[in] input tensor
///
@param output[out] output tensor
///
@
return Status ok/error
///
\
param input[in] input tensor
///
\param output[out] output vector of CVTensors
///
\
return Status ok/error
Status
BatchTensorToCVTensorVector
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
CVTensor
>>
*
output
);
/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional Tensors
/// \param input[in] input tensor
/// \param output[out] output vector of tensors
/// \return Status ok/error
Status
BatchTensorToTensorVector
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
);
/// Convert a vector of (n-1)-dimensional Tensors to an n-dimensional Tensor
/// \param input[in] input vector of tensors
/// \param output[out] output tensor
/// \return Status ok/error
Status
TensorVectorToBatchTensor
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
);
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
浏览文件 @
113ff6ca
...
...
@@ -7,6 +7,7 @@ add_library(kernels-image OBJECT
center_crop_op.cc
crop_op.cc
cut_out_op.cc
cutmix_batch_op.cc
decode_op.cc
equalize_op.cc
hwc_to_chw_op.cc
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
0 → 100644
浏览文件 @
113ff6ca
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <utility>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
CutMixBatchOp
::
CutMixBatchOp
(
ImageBatchFormat
image_batch_format
,
float
alpha
,
float
prob
)
:
image_batch_format_
(
image_batch_format
),
alpha_
(
alpha
),
prob_
(
prob
)
{
rnd_
.
seed
(
GetSeed
());
}
void
CutMixBatchOp
::
GetCropBox
(
int
height
,
int
width
,
float
lam
,
int
*
x
,
int
*
y
,
int
*
crop_width
,
int
*
crop_height
)
{
float
cut_ratio
=
1
-
lam
;
int
cut_w
=
static_cast
<
int
>
(
width
*
cut_ratio
);
int
cut_h
=
static_cast
<
int
>
(
height
*
cut_ratio
);
std
::
uniform_int_distribution
<
int
>
width_uniform_distribution
(
0
,
width
);
std
::
uniform_int_distribution
<
int
>
height_uniform_distribution
(
0
,
height
);
int
cx
=
width_uniform_distribution
(
rnd_
);
int
x2
,
y2
;
int
cy
=
height_uniform_distribution
(
rnd_
);
*
x
=
std
::
clamp
(
cx
-
cut_w
/
2
,
0
,
width
-
1
);
// horizontal coordinate of left side of crop box
*
y
=
std
::
clamp
(
cy
-
cut_h
/
2
,
0
,
height
-
1
);
// vertical coordinate of the top side of crop box
x2
=
std
::
clamp
(
cx
+
cut_w
/
2
,
0
,
width
-
1
);
// horizontal coordinate of right side of crop box
y2
=
std
::
clamp
(
cy
+
cut_h
/
2
,
0
,
height
-
1
);
// vertical coordinate of the bottom side of crop box
*
crop_width
=
std
::
clamp
(
x2
-
*
x
,
1
,
width
-
1
);
*
crop_height
=
std
::
clamp
(
y2
-
*
y
,
1
,
height
-
1
);
}
Status
CutMixBatchOp
::
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
{
if
(
input
.
size
()
<
2
)
{
RETURN_STATUS_UNEXPECTED
(
"Both images and labels columns are required for this operation"
);
}
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
images
;
std
::
vector
<
int64_t
>
image_shape
=
input
.
at
(
0
)
->
shape
().
AsVector
();
std
::
vector
<
int64_t
>
label_shape
=
input
.
at
(
1
)
->
shape
().
AsVector
();
// Check inputs
if
(
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
RETURN_STATUS_UNEXPECTED
(
"You must batch before calling CutMixBatch."
);
}
if
(
label_shape
.
size
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Label's must be in one-hot format and in a batch"
);
}
if
((
image_shape
[
1
]
!=
1
&&
image_shape
[
1
]
!=
3
)
&&
image_batch_format_
==
ImageBatchFormat
::
kNCHW
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Image doesn't match the given image format."
);
}
if
((
image_shape
[
3
]
!=
1
&&
image_shape
[
3
]
!=
3
)
&&
image_batch_format_
==
ImageBatchFormat
::
kNHWC
)
{
RETURN_STATUS_UNEXPECTED
(
"CutMixBatch: Image doesn't match the given image format."
);
}
// Move images into a vector of Tensors
RETURN_IF_NOT_OK
(
BatchTensorToTensorVector
(
input
.
at
(
0
),
&
images
));
// Calculate random labels
std
::
vector
<
int64_t
>
rand_indx
;
for
(
int64_t
i
=
0
;
i
<
images
.
size
();
i
++
)
rand_indx
.
push_back
(
i
);
std
::
shuffle
(
rand_indx
.
begin
(),
rand_indx
.
end
(),
rnd_
);
std
::
gamma_distribution
<
float
>
gamma_distribution
(
alpha_
,
1
);
std
::
uniform_real_distribution
<
double
>
uniform_distribution
(
0.0
,
1.0
);
// Tensor holding the output labels
std
::
shared_ptr
<
Tensor
>
out_labels
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
TensorShape
(
label_shape
),
DataType
(
DataType
::
DE_FLOAT32
),
&
out_labels
));
// Compute labels and images
for
(
int
i
=
0
;
i
<
image_shape
[
0
];
i
++
)
{
// Calculating lambda
// If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1)
// then x = x1 / (x1+x2) is a random variable from Beta(a1, a2)
float
x1
=
gamma_distribution
(
rnd_
);
float
x2
=
gamma_distribution
(
rnd_
);
float
lam
=
x1
/
(
x1
+
x2
);
double
random_number
=
uniform_distribution
(
rnd_
);
if
(
random_number
<
prob_
)
{
int
x
,
y
,
crop_width
,
crop_height
;
float
label_lam
;
// lambda used for labels
// Get a random image
TensorShape
remaining
({
-
1
});
uchar
*
start_addr_of_index
=
nullptr
;
std
::
shared_ptr
<
Tensor
>
rand_image
;
RETURN_IF_NOT_OK
(
input
.
at
(
0
)
->
StartAddrOfIndex
({
rand_indx
[
i
],
0
,
0
,
0
},
&
start_addr_of_index
,
&
remaining
));
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
TensorShape
({
image_shape
[
1
],
image_shape
[
2
],
image_shape
[
3
]}),
input
.
at
(
0
)
->
type
(),
start_addr_of_index
,
&
rand_image
));
// Compute image
if
(
image_batch_format_
==
ImageBatchFormat
::
kNHWC
)
{
// NHWC Format
GetCropBox
(
static_cast
<
int32_t
>
(
image_shape
[
1
]),
static_cast
<
int32_t
>
(
image_shape
[
2
]),
lam
,
&
x
,
&
y
,
&
crop_width
,
&
crop_height
);
std
::
shared_ptr
<
Tensor
>
cropped
;
RETURN_IF_NOT_OK
(
Crop
(
rand_image
,
&
cropped
,
x
,
y
,
crop_width
,
crop_height
));
RETURN_IF_NOT_OK
(
MaskWithTensor
(
cropped
,
&
images
[
i
],
x
,
y
,
crop_width
,
crop_height
,
ImageFormat
::
HWC
));
label_lam
=
1
-
(
crop_width
*
crop_height
/
static_cast
<
float
>
(
image_shape
[
1
]
*
image_shape
[
2
]));
}
else
{
// NCHW Format
GetCropBox
(
static_cast
<
int32_t
>
(
image_shape
[
2
]),
static_cast
<
int32_t
>
(
image_shape
[
3
]),
lam
,
&
x
,
&
y
,
&
crop_width
,
&
crop_height
);
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
channels
;
// A vector holding channels of the CHW image
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
cropped_channels
;
// A vector holding the channels of the cropped CHW
RETURN_IF_NOT_OK
(
BatchTensorToTensorVector
(
rand_image
,
&
channels
));
for
(
auto
channel
:
channels
)
{
// Call crop for each single channel
std
::
shared_ptr
<
Tensor
>
cropped_channel
;
RETURN_IF_NOT_OK
(
Crop
(
channel
,
&
cropped_channel
,
x
,
y
,
crop_width
,
crop_height
));
cropped_channels
.
push_back
(
cropped_channel
);
}
std
::
shared_ptr
<
Tensor
>
cropped
;
// Merge channels to a single tensor
RETURN_IF_NOT_OK
(
TensorVectorToBatchTensor
(
cropped_channels
,
&
cropped
));
RETURN_IF_NOT_OK
(
MaskWithTensor
(
cropped
,
&
images
[
i
],
x
,
y
,
crop_width
,
crop_height
,
ImageFormat
::
CHW
));
label_lam
=
1
-
(
crop_width
*
crop_height
/
static_cast
<
float
>
(
image_shape
[
2
]
*
image_shape
[
3
]));
}
// Compute labels
for
(
int
j
=
0
;
j
<
label_shape
[
1
];
j
++
)
{
uint64_t
first_value
,
second_value
;
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
first_value
,
{
i
,
j
}));
RETURN_IF_NOT_OK
(
input
.
at
(
1
)
->
GetItemAt
(
&
second_value
,
{
rand_indx
[
i
]
%
label_shape
[
0
],
j
}));
RETURN_IF_NOT_OK
(
out_labels
->
SetItemAt
({
i
,
j
},
label_lam
*
first_value
+
(
1
-
label_lam
)
*
second_value
));
}
}
}
std
::
shared_ptr
<
Tensor
>
out_images
;
RETURN_IF_NOT_OK
(
TensorVectorToBatchTensor
(
images
,
&
out_images
));
// Move the output into a TensorRow
output
->
push_back
(
out_images
);
output
->
push_back
(
out_labels
);
return
Status
::
OK
();
}
void
CutMixBatchOp
::
Print
(
std
::
ostream
&
out
)
const
{
out
<<
"CutMixBatchOp: "
<<
"image_batch_format: "
<<
image_batch_format_
<<
"alpha: "
<<
alpha_
<<
", probability: "
<<
prob_
<<
"
\n
"
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.h
0 → 100644
浏览文件 @
113ff6ca
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_
#include <memory>
#include <vector>
#include <random>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace
mindspore
{
namespace
dataset
{
class
CutMixBatchOp
:
public
TensorOp
{
public:
explicit
CutMixBatchOp
(
ImageBatchFormat
image_batch_format
,
float
alpha
,
float
prob
);
~
CutMixBatchOp
()
override
=
default
;
void
Print
(
std
::
ostream
&
out
)
const
override
;
void
GetCropBox
(
int
width
,
int
height
,
float
lam
,
int
*
x
,
int
*
y
,
int
*
crop_width
,
int
*
crop_height
);
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
override
;
std
::
string
Name
()
const
override
{
return
kCutMixBatchOp
;
}
private:
float
alpha_
;
float
prob_
;
ImageBatchFormat
image_batch_format_
;
std
::
mt19937
rnd_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_CUTMIXBATCH_OP_H_
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
浏览文件 @
113ff6ca
...
...
@@ -402,6 +402,62 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output)
}
}
Status
MaskWithTensor
(
const
std
::
shared_ptr
<
Tensor
>
&
sub_mat
,
std
::
shared_ptr
<
Tensor
>
*
input
,
int
x
,
int
y
,
int
crop_width
,
int
crop_height
,
ImageFormat
image_format
)
{
if
(
image_format
==
ImageFormat
::
HWC
)
{
if
((
*
input
)
->
Rank
()
!=
3
||
((
*
input
)
->
shape
()[
2
]
!=
1
&&
(
*
input
)
->
shape
()[
2
]
!=
3
))
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: Image shape doesn't match the given image_format."
);
}
if
(
sub_mat
->
Rank
()
!=
3
||
(
sub_mat
->
shape
()[
2
]
!=
1
&&
sub_mat
->
shape
()[
2
]
!=
3
))
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: sub_mat shape doesn't match the given image_format."
);
}
int
number_of_channels
=
(
*
input
)
->
shape
()[
2
];
for
(
int
i
=
0
;
i
<
crop_width
;
i
++
)
{
for
(
int
j
=
0
;
j
<
crop_height
;
j
++
)
{
for
(
int
c
=
0
;
c
<
number_of_channels
;
c
++
)
{
uint8_t
pixel_value
;
RETURN_IF_NOT_OK
(
sub_mat
->
GetItemAt
(
&
pixel_value
,
{
j
,
i
,
c
}));
RETURN_IF_NOT_OK
((
*
input
)
->
SetItemAt
({
y
+
j
,
x
+
i
,
c
},
pixel_value
));
}
}
}
}
else
if
(
image_format
==
ImageFormat
::
CHW
)
{
if
((
*
input
)
->
Rank
()
!=
3
||
((
*
input
)
->
shape
()[
0
]
!=
1
&&
(
*
input
)
->
shape
()[
0
]
!=
3
))
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: Image shape doesn't match the given image_format."
);
}
if
(
sub_mat
->
Rank
()
!=
3
||
(
sub_mat
->
shape
()[
0
]
!=
1
&&
sub_mat
->
shape
()[
0
]
!=
3
))
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: sub_mat shape doesn't match the given image_format."
);
}
int
number_of_channels
=
(
*
input
)
->
shape
()[
0
];
for
(
int
i
=
0
;
i
<
crop_width
;
i
++
)
{
for
(
int
j
=
0
;
j
<
crop_height
;
j
++
)
{
for
(
int
c
=
0
;
c
<
number_of_channels
;
c
++
)
{
uint8_t
pixel_value
;
RETURN_IF_NOT_OK
(
sub_mat
->
GetItemAt
(
&
pixel_value
,
{
c
,
j
,
i
}));
RETURN_IF_NOT_OK
((
*
input
)
->
SetItemAt
({
c
,
y
+
j
,
x
+
i
},
pixel_value
));
}
}
}
}
else
if
(
image_format
==
ImageFormat
::
HW
)
{
if
((
*
input
)
->
Rank
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: Image shape doesn't match the given image_format."
);
}
if
(
sub_mat
->
Rank
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: sub_mat shape doesn't match the given image_format."
);
}
for
(
int
i
=
0
;
i
<
crop_width
;
i
++
)
{
for
(
int
j
=
0
;
j
<
crop_height
;
j
++
)
{
uint8_t
pixel_value
;
RETURN_IF_NOT_OK
(
sub_mat
->
GetItemAt
(
&
pixel_value
,
{
j
,
i
}));
RETURN_IF_NOT_OK
((
*
input
)
->
SetItemAt
({
y
+
j
,
x
+
i
},
pixel_value
));
}
}
}
else
{
RETURN_STATUS_UNEXPECTED
(
"MaskWithTensor: Image format must be CHW, HWC, or HW."
);
}
return
Status
::
OK
();
}
Status
SwapRedAndBlue
(
std
::
shared_ptr
<
Tensor
>
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
try
{
std
::
shared_ptr
<
CVTensor
>
input_cv
=
CVTensor
::
AsCVTensor
(
std
::
move
(
input
));
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
浏览文件 @
113ff6ca
...
...
@@ -120,6 +120,19 @@ Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
/// \param output: Tensor of shape <C,H,W> or <H,W> and same input type.
Status
HwcToChw
(
std
::
shared_ptr
<
Tensor
>
input
,
std
::
shared_ptr
<
Tensor
>
*
output
);
/// \brief Masks the given part of the input image with a another image (sub_mat)
/// \param[in] sub_mat The image we want to mask with
/// \param[in] input The pointer to the image we want to mask
/// \param[in] x The horizontal coordinate of left side of crop box
/// \param[in] y The vertical coordinate of the top side of crop box
/// \param[in] width The width of the mask box
/// \param[in] height The height of the mask box
/// \param[in] image_format The format of the image (CHW or HWC)
/// \param[out] input Masks the input image in-place and returns it
/// @return Status ok/error
Status
MaskWithTensor
(
const
std
::
shared_ptr
<
Tensor
>
&
sub_mat
,
std
::
shared_ptr
<
Tensor
>
*
input
,
int
x
,
int
y
,
int
width
,
int
height
,
ImageFormat
image_format
);
/// \brief Swap the red and blue pixels (RGB <-> BGR)
/// \param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor.
/// \param output: Swapped image of same shape and type
...
...
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
浏览文件 @
113ff6ca
...
...
@@ -37,10 +37,12 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
std
::
vector
<
int64_t
>
label_shape
=
input
.
at
(
1
)
->
shape
().
AsVector
();
// Check inputs
if
(
label_shape
.
size
()
!=
2
||
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
if
(
image_shape
.
size
()
!=
4
||
image_shape
[
0
]
!=
label_shape
[
0
])
{
RETURN_STATUS_UNEXPECTED
(
"You must batch before calling MixUpBatch"
);
}
if
(
label_shape
.
size
()
!=
2
)
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Label's must be in one-hot format and in a batch"
);
}
if
((
image_shape
[
1
]
!=
1
&&
image_shape
[
1
]
!=
3
)
&&
(
image_shape
[
3
]
!=
1
&&
image_shape
[
3
]
!=
3
))
{
RETURN_STATUS_UNEXPECTED
(
"MixUpBatch: Images must be in the shape of HWC or CHW"
);
}
...
...
mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
浏览文件 @
113ff6ca
...
...
@@ -94,6 +94,7 @@ constexpr char kAutoContrastOp[] = "AutoContrastOp";
constexpr
char
kBoundingBoxAugmentOp
[]
=
"BoundingBoxAugmentOp"
;
constexpr
char
kDecodeOp
[]
=
"DecodeOp"
;
constexpr
char
kCenterCropOp
[]
=
"CenterCropOp"
;
constexpr
char
kCutMixBatchOp
[]
=
"CutMixBatchOp"
;
constexpr
char
kCutOutOp
[]
=
"CutOutOp"
;
constexpr
char
kCropOp
[]
=
"CropOp"
;
constexpr
char
kEqualizeOp
[]
=
"EqualizeOp"
;
...
...
mindspore/dataset/transforms/vision/c_transforms.py
浏览文件 @
113ff6ca
...
...
@@ -43,13 +43,14 @@ Examples:
import
numbers
import
mindspore._c_dataengine
as
cde
from
.utils
import
Inter
,
Border
from
.utils
import
Inter
,
Border
,
ImageBatchFormat
from
.validators
import
check_prob
,
check_crop
,
check_resize_interpolation
,
check_random_resize_crop
,
\
check_mix_up_batch_c
,
check_normalize_c
,
check_random_crop
,
check_random_color_adjust
,
check_random_rotation
,
\
check_range
,
check_resize
,
check_rescale
,
check_pad
,
check_cutout
,
\
check_uniform_augment_cpp
,
\
check_bounding_box_augment_cpp
,
check_random_select_subpolicy_op
,
check_auto_contrast
,
check_random_affine
,
\
check_random_solarize
,
check_soft_dvpp_decode_random_crop_resize_jpeg
,
check_positive_degrees
,
FLOAT_MAX_INTEGER
check_random_solarize
,
check_soft_dvpp_decode_random_crop_resize_jpeg
,
check_positive_degrees
,
FLOAT_MAX_INTEGER
,
\
check_cut_mix_batch_c
DE_C_INTER_MODE
=
{
Inter
.
NEAREST
:
cde
.
InterpolationMode
.
DE_INTER_NEAREST_NEIGHBOUR
,
Inter
.
LINEAR
:
cde
.
InterpolationMode
.
DE_INTER_LINEAR
,
...
...
@@ -60,6 +61,8 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
Border
.
REFLECT
:
cde
.
BorderType
.
DE_BORDER_REFLECT
,
Border
.
SYMMETRIC
:
cde
.
BorderType
.
DE_BORDER_SYMMETRIC
}
DE_C_IMAGE_BATCH_FORMAT
=
{
ImageBatchFormat
.
NHWC
:
cde
.
ImageBatchFormat
.
DE_IMAGE_BATCH_FORMAT_NHWC
,
ImageBatchFormat
.
NCHW
:
cde
.
ImageBatchFormat
.
DE_IMAGE_BATCH_FORMAT_NCHW
}
def
parse_padding
(
padding
):
if
isinstance
(
padding
,
numbers
.
Number
):
...
...
@@ -143,6 +146,33 @@ class Decode(cde.DecodeOp):
super
().
__init__
(
self
.
rgb
)
class
CutMixBatch
(
cde
.
CutMixBatchOp
):
"""
Apply CutMix transformation on input batch of images and labels.
Note that you need to make labels into one-hot format and batch before calling this function.
Args:
image_batch_format (Image Batch Format): The method of padding. Can be any of
[ImageBatchFormat.NHWC, ImageBatchFormat.NCHW]
alpha (float): hyperparameter of beta distribution (default = 1.0).
prob (float): The probability by which CutMix is applied to each image (default = 1.0).
Examples:
>>> one_hot_op = data.OneHot(num_classes=10)
>>> data = data.map(input_columns=["label"], operations=one_hot_op)
>>> cutmix_batch_op = vision.CutMixBatch(ImageBatchFormat.NHWC, 1.0, 0.5)
>>> data = data.batch(5)
>>> data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
"""
@
check_cut_mix_batch_c
def
__init__
(
self
,
image_batch_format
,
alpha
=
1.0
,
prob
=
1.0
):
self
.
image_batch_format
=
image_batch_format
.
value
self
.
alpha
=
alpha
self
.
prob
=
prob
super
().
__init__
(
DE_C_IMAGE_BATCH_FORMAT
[
image_batch_format
],
alpha
,
prob
)
class
CutOut
(
cde
.
CutOutOp
):
"""
Randomly cut (mask) out a given number of square patches from the input Numpy image array.
...
...
mindspore/dataset/transforms/vision/utils.py
浏览文件 @
113ff6ca
...
...
@@ -30,3 +30,9 @@ class Border(str, Enum):
EDGE
:
str
=
"edge"
REFLECT
:
str
=
"reflect"
SYMMETRIC
:
str
=
"symmetric"
# Image Batch Format
class
ImageBatchFormat
(
IntEnum
):
NHWC
=
0
NCHW
=
1
mindspore/dataset/transforms/vision/validators.py
浏览文件 @
113ff6ca
...
...
@@ -19,7 +19,7 @@ from functools import wraps
import
numpy
as
np
from
mindspore._c_dataengine
import
TensorOp
from
.utils
import
Inter
,
Border
from
.utils
import
Inter
,
Border
,
ImageBatchFormat
from
...core.validator_helpers
import
check_value
,
check_uint8
,
FLOAT_MAX_INTEGER
,
check_pos_float32
,
\
check_2tuple
,
check_range
,
check_positive
,
INT32_MAX
,
parse_user_args
,
type_check
,
type_check_list
,
\
check_tensor_op
,
UINT8_MAX
...
...
@@ -37,6 +37,20 @@ def check_crop_size(size):
raise
TypeError
(
"Size should be a single integer or a list/tuple (h, w) of length 2."
)
def
check_cut_mix_batch_c
(
method
):
"""Wrapper method to check the parameters of CutMixBatch."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
[
image_batch_format
,
alpha
,
prob
],
_
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
type_check
(
image_batch_format
,
(
ImageBatchFormat
,),
"image_batch_format"
)
check_pos_float32
(
alpha
)
check_value
(
prob
,
[
0
,
1
],
"prob"
)
return
method
(
self
,
*
args
,
**
kwargs
)
return
new_method
def
check_resize_size
(
size
):
"""Wrapper method to check the parameters of resize."""
if
isinstance
(
size
,
int
):
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
113ff6ca
...
...
@@ -20,6 +20,7 @@ SET(DE_UT_SRCS
circular_pool_test.cc
client_config_test.cc
connector_test.cc
cutmix_batch_op_test.cc
cut_out_op_test.cc
datatype_test.cc
decode_op_test.cc
...
...
tests/ut/cpp/dataset/c_api_transforms_test.cc
浏览文件 @
113ff6ca
...
...
@@ -25,6 +25,177 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F
(
MindDataTestPipeline
,
TestCutMixBatchSuccess1
)
{
// Testing CutMixBatch on a batch of CHW images
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
int
number_of_classes
=
10
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
hwc_to_chw
=
vision
::
HWC2CHW
();
EXPECT_NE
(
hwc_to_chw
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
hwc_to_chw
},{
"image"
});
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
5
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
one_hot_op
=
vision
::
OneHot
(
number_of_classes
);
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNCHW
,
1.0
,
1.0
);
EXPECT_NE
(
cutmix_batch_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
cutmix_batch_op
},
{
"image"
,
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
i
++
;
auto
image
=
row
[
"image"
];
auto
label
=
row
[
"label"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
MS_LOG
(
INFO
)
<<
"Label shape: "
<<
label
->
shape
();
EXPECT_EQ
(
image
->
shape
().
AsVector
().
size
()
==
4
&&
batch_size
==
image
->
shape
()[
0
]
&&
3
==
image
->
shape
()[
1
]
&&
32
==
image
->
shape
()[
2
]
&&
32
==
image
->
shape
()[
3
],
true
);
EXPECT_EQ
(
label
->
shape
().
AsVector
().
size
()
==
2
&&
batch_size
==
label
->
shape
()[
0
]
&&
number_of_classes
==
label
->
shape
()[
1
],
true
);
iter
->
GetNextRow
(
&
row
);
}
EXPECT_EQ
(
i
,
2
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestCutMixBatchSuccess2
)
{
// Calling CutMixBatch on a batch of HWC images with default values of alpha and prob
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
int
number_of_classes
=
10
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
5
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
one_hot_op
=
vision
::
OneHot
(
number_of_classes
);
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
);
EXPECT_NE
(
cutmix_batch_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
cutmix_batch_op
},
{
"image"
,
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
i
++
;
auto
image
=
row
[
"image"
];
auto
label
=
row
[
"label"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
MS_LOG
(
INFO
)
<<
"Label shape: "
<<
label
->
shape
();
EXPECT_EQ
(
image
->
shape
().
AsVector
().
size
()
==
4
&&
batch_size
==
image
->
shape
()[
0
]
&&
32
==
image
->
shape
()[
1
]
&&
32
==
image
->
shape
()[
2
]
&&
3
==
image
->
shape
()[
3
],
true
);
EXPECT_EQ
(
label
->
shape
().
AsVector
().
size
()
==
2
&&
batch_size
==
label
->
shape
()[
0
]
&&
number_of_classes
==
label
->
shape
()[
1
],
true
);
iter
->
GetNextRow
(
&
row
);
}
EXPECT_EQ
(
i
,
2
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestCutMixBatchFail1
)
{
// Must fail because alpha can't be negative
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
5
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
one_hot_op
=
vision
::
OneHot
(
10
);
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
-
1
,
0.5
);
EXPECT_EQ
(
cutmix_batch_op
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestCutMixBatchFail2
)
{
// Must fail because prob can't be negative
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
10
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
5
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create objects for the tensor ops
std
::
shared_ptr
<
TensorOperation
>
one_hot_op
=
vision
::
OneHot
(
10
);
EXPECT_NE
(
one_hot_op
,
nullptr
);
// Create a Map operation on ds
ds
=
ds
->
Map
({
one_hot_op
},{
"label"
});
EXPECT_NE
(
ds
,
nullptr
);
std
::
shared_ptr
<
TensorOperation
>
cutmix_batch_op
=
vision
::
CutMixBatch
(
mindspore
::
dataset
::
ImageBatchFormat
::
kNHWC
,
1
,
-
0.5
);
EXPECT_EQ
(
cutmix_batch_op
,
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestCutOut
)
{
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
...
...
tests/ut/cpp/dataset/cutmix_batch_op_test.cc
0 → 100644
浏览文件 @
113ff6ca
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/cutmix_batch_op.h"
#include "utils/log_adapter.h"
using
namespace
mindspore
::
dataset
;
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
INFO
;
class
MindDataTestCutMixBatchOp
:
public
UT
::
CVOP
::
CVOpCommon
{
protected:
MindDataTestCutMixBatchOp
()
:
CVOpCommon
()
{}
};
TEST_F
(
MindDataTestCutMixBatchOp
,
TestSuccess1
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestCutMixBatchOp success1 case"
;
std
::
shared_ptr
<
Tensor
>
batched_tensor
;
std
::
shared_ptr
<
Tensor
>
batched_labels
;
Tensor
::
CreateEmpty
(
TensorShape
({
2
,
input_tensor_
->
shape
()[
0
],
input_tensor_
->
shape
()[
1
],
input_tensor_
->
shape
()[
2
]}),
input_tensor_
->
type
(),
&
batched_tensor
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
batched_tensor
->
InsertTensor
({
i
},
input_tensor_
);
}
Tensor
::
CreateFromVector
(
std
::
vector
<
uint32_t
>
({
0
,
1
,
1
,
0
}),
TensorShape
({
2
,
2
}),
&
batched_labels
);
std
::
shared_ptr
<
CutMixBatchOp
>
op
=
std
::
make_shared
<
CutMixBatchOp
>
(
ImageBatchFormat
::
kNHWC
,
1.0
,
1.0
);
TensorRow
in
;
in
.
push_back
(
batched_tensor
);
in
.
push_back
(
batched_labels
);
TensorRow
out
;
ASSERT_TRUE
(
op
->
Compute
(
in
,
&
out
).
IsOk
());
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
0
],
out
.
at
(
0
)
->
shape
()[
0
]);
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
1
],
out
.
at
(
0
)
->
shape
()[
1
]);
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
2
],
out
.
at
(
0
)
->
shape
()[
2
]);
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
3
],
out
.
at
(
0
)
->
shape
()[
3
]);
EXPECT_EQ
(
in
.
at
(
1
)
->
shape
()[
0
],
out
.
at
(
1
)
->
shape
()[
0
]);
EXPECT_EQ
(
in
.
at
(
1
)
->
shape
()[
1
],
out
.
at
(
1
)
->
shape
()[
1
]);
}
TEST_F
(
MindDataTestCutMixBatchOp
,
TestSuccess2
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestCutMixBatchOp success2 case"
;
std
::
shared_ptr
<
Tensor
>
batched_tensor
;
std
::
shared_ptr
<
Tensor
>
batched_labels
;
std
::
shared_ptr
<
Tensor
>
chw_tensor
;
ASSERT_TRUE
(
HwcToChw
(
input_tensor_
,
&
chw_tensor
).
IsOk
());
Tensor
::
CreateEmpty
(
TensorShape
({
2
,
chw_tensor
->
shape
()[
0
],
chw_tensor
->
shape
()[
1
],
chw_tensor
->
shape
()[
2
]}),
chw_tensor
->
type
(),
&
batched_tensor
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
batched_tensor
->
InsertTensor
({
i
},
chw_tensor
);
}
Tensor
::
CreateFromVector
(
std
::
vector
<
uint32_t
>
({
0
,
1
,
1
,
0
}),
TensorShape
({
2
,
2
}),
&
batched_labels
);
std
::
shared_ptr
<
CutMixBatchOp
>
op
=
std
::
make_shared
<
CutMixBatchOp
>
(
ImageBatchFormat
::
kNCHW
,
1.0
,
0.5
);
TensorRow
in
;
in
.
push_back
(
batched_tensor
);
in
.
push_back
(
batched_labels
);
TensorRow
out
;
ASSERT_TRUE
(
op
->
Compute
(
in
,
&
out
).
IsOk
());
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
0
],
out
.
at
(
0
)
->
shape
()[
0
]);
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
1
],
out
.
at
(
0
)
->
shape
()[
1
]);
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
2
],
out
.
at
(
0
)
->
shape
()[
2
]);
EXPECT_EQ
(
in
.
at
(
0
)
->
shape
()[
3
],
out
.
at
(
0
)
->
shape
()[
3
]);
EXPECT_EQ
(
in
.
at
(
1
)
->
shape
()[
0
],
out
.
at
(
1
)
->
shape
()[
0
]);
EXPECT_EQ
(
in
.
at
(
1
)
->
shape
()[
1
],
out
.
at
(
1
)
->
shape
()[
1
]);
}
TEST_F
(
MindDataTestCutMixBatchOp
,
TestFail1
)
{
// This is a fail case because our labels are not batched and are 1-dimensional
MS_LOG
(
INFO
)
<<
"Doing MindDataTestCutMixBatchOp fail1 case"
;
std
::
shared_ptr
<
Tensor
>
labels
;
Tensor
::
CreateFromVector
(
std
::
vector
<
uint32_t
>
({
0
,
1
,
1
,
0
}),
TensorShape
({
4
}),
&
labels
);
std
::
shared_ptr
<
CutMixBatchOp
>
op
=
std
::
make_shared
<
CutMixBatchOp
>
(
ImageBatchFormat
::
kNHWC
,
1.0
,
1.0
);
TensorRow
in
;
in
.
push_back
(
input_tensor_
);
in
.
push_back
(
labels
);
TensorRow
out
;
ASSERT_FALSE
(
op
->
Compute
(
in
,
&
out
).
IsOk
());
}
TEST_F
(
MindDataTestCutMixBatchOp
,
TestFail2
)
{
// This should fail because the image_batch_format provided is not the same as the actual format of the images
MS_LOG
(
INFO
)
<<
"Doing MindDataTestCutMixBatchOp fail2 case"
;
std
::
shared_ptr
<
Tensor
>
batched_tensor
;
std
::
shared_ptr
<
Tensor
>
batched_labels
;
Tensor
::
CreateEmpty
(
TensorShape
({
2
,
input_tensor_
->
shape
()[
0
],
input_tensor_
->
shape
()[
1
],
input_tensor_
->
shape
()[
2
]}),
input_tensor_
->
type
(),
&
batched_tensor
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
batched_tensor
->
InsertTensor
({
i
},
input_tensor_
);
}
Tensor
::
CreateFromVector
(
std
::
vector
<
uint32_t
>
({
0
,
1
,
1
,
0
}),
TensorShape
({
2
,
2
}),
&
batched_labels
);
std
::
shared_ptr
<
CutMixBatchOp
>
op
=
std
::
make_shared
<
CutMixBatchOp
>
(
ImageBatchFormat
::
kNCHW
,
1.0
,
1.0
);
TensorRow
in
;
in
.
push_back
(
batched_tensor
);
in
.
push_back
(
batched_labels
);
TensorRow
out
;
ASSERT_FALSE
(
op
->
Compute
(
in
,
&
out
).
IsOk
());
}
tests/ut/data/dataset/golden/cutmix_batch_c_nchw_result.npz
0 → 100644
浏览文件 @
113ff6ca
文件已添加
tests/ut/data/dataset/golden/cutmix_batch_c_nhwc_result.npz
0 → 100644
浏览文件 @
113ff6ca
文件已添加
tests/ut/python/dataset/test_cutmix_batch_op.py
0 → 100644
浏览文件 @
113ff6ca
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the CutMixBatch op in DE
"""
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset.transforms.c_transforms
as
data_trans
import
mindspore.dataset.transforms.vision.utils
as
mode
from
mindspore
import
log
as
logger
from
util
import
save_and_check_md5
,
diff_mse
,
visualize_list
,
config_get_set_seed
,
\
config_get_set_num_parallel_workers
DATA_DIR
=
"../data/dataset/testCifar10Data"
GENERATE_GOLDEN
=
False
def
test_cutmix_batch_success1
(
plot
=
False
):
"""
Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
"""
logger
.
info
(
"test_cutmix_batch_success1"
)
# Original Images
ds_original
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
ds_original
=
ds_original
.
batch
(
5
,
drop_remainder
=
True
)
images_original
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
ds_original
):
if
idx
==
0
:
images_original
=
image
else
:
images_original
=
np
.
append
(
images_original
,
image
,
axis
=
0
)
# CutMix Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
hwc2chw_op
=
vision
.
HWC2CHW
()
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
hwc2chw_op
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NCHW
,
2.0
,
0.5
)
data1
=
data1
.
batch
(
5
,
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
images_cutmix
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
.
transpose
(
0
,
2
,
3
,
1
)
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
.
transpose
(
0
,
2
,
3
,
1
),
axis
=
0
)
if
plot
:
visualize_list
(
images_original
,
images_cutmix
)
num_samples
=
images_original
.
shape
[
0
]
mse
=
np
.
zeros
(
num_samples
)
for
i
in
range
(
num_samples
):
mse
[
i
]
=
diff_mse
(
images_cutmix
[
i
],
images_original
[
i
])
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def
test_cutmix_batch_success2
(
plot
=
False
):
"""
Test CutMixBatch op with default values for alpha and prob on a batch of HWC images
"""
logger
.
info
(
"test_cutmix_batch_success2"
)
# Original Images
ds_original
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
ds_original
=
ds_original
.
batch
(
5
,
drop_remainder
=
True
)
images_original
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
ds_original
):
if
idx
==
0
:
images_original
=
image
else
:
images_original
=
np
.
append
(
images_original
,
image
,
axis
=
0
)
# CutMix Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
)
data1
=
data1
.
batch
(
5
,
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
images_cutmix
=
None
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
if
plot
:
visualize_list
(
images_original
,
images_cutmix
)
num_samples
=
images_original
.
shape
[
0
]
mse
=
np
.
zeros
(
num_samples
)
for
i
in
range
(
num_samples
):
mse
[
i
]
=
diff_mse
(
images_cutmix
[
i
],
images_original
[
i
])
logger
.
info
(
"MSE= {}"
.
format
(
str
(
np
.
mean
(
mse
))))
def
test_cutmix_batch_nhwc_md5
():
"""
Test CutMixBatch on a batch of HWC images with MD5:
"""
logger
.
info
(
"test_cutmix_batch_nhwc_md5"
)
original_seed
=
config_get_set_seed
(
0
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
# CutMixBatch Images
data
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data
=
data
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
)
data
=
data
.
batch
(
5
,
drop_remainder
=
True
)
data
=
data
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
filename
=
"cutmix_batch_c_nhwc_result.npz"
save_and_check_md5
(
data
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_cutmix_batch_nchw_md5
():
"""
Test CutMixBatch on a batch of CHW images with MD5:
"""
logger
.
info
(
"test_cutmix_batch_nchw_md5"
)
original_seed
=
config_get_set_seed
(
0
)
original_num_parallel_workers
=
config_get_set_num_parallel_workers
(
1
)
# CutMixBatch Images
data
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
hwc2chw_op
=
vision
.
HWC2CHW
()
data
=
data
.
map
(
input_columns
=
[
"image"
],
operations
=
hwc2chw_op
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data
=
data
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NCHW
)
data
=
data
.
batch
(
5
,
drop_remainder
=
True
)
data
=
data
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
filename
=
"cutmix_batch_c_nchw_result.npz"
save_and_check_md5
(
data
,
filename
,
generate_golden
=
GENERATE_GOLDEN
)
# Restore config setting
ds
.
config
.
set_seed
(
original_seed
)
ds
.
config
.
set_num_parallel_workers
(
original_num_parallel_workers
)
def
test_cutmix_batch_fail1
():
"""
Test CutMixBatch Fail 1
We expect this to fail because the images and labels are not batched
"""
logger
.
info
(
"test_cutmix_batch_fail1"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
)
with
pytest
.
raises
(
RuntimeError
)
as
error
:
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
error_message
=
"You must batch before calling CutMixBatch"
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail2
():
"""
Test CutMixBatch Fail 2
We expect this to fail because alpha is negative
"""
logger
.
info
(
"test_cutmix_batch_fail2"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
with
pytest
.
raises
(
ValueError
)
as
error
:
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
,
-
1
)
error_message
=
"Input is not within the required interval"
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail3
():
"""
Test CutMixBatch Fail 2
We expect this to fail because prob is larger than 1
"""
logger
.
info
(
"test_cutmix_batch_fail3"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
with
pytest
.
raises
(
ValueError
)
as
error
:
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
,
1
,
2
)
error_message
=
"Input is not within the required interval"
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail4
():
"""
Test CutMixBatch Fail 2
We expect this to fail because prob is negative
"""
logger
.
info
(
"test_cutmix_batch_fail4"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
with
pytest
.
raises
(
ValueError
)
as
error
:
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
,
1
,
-
1
)
error_message
=
"Input is not within the required interval"
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail5
():
"""
Test CutMixBatch op
We expect this to fail because label column is not passed to cutmix_batch
"""
logger
.
info
(
"test_cutmix_batch_fail5"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
)
data1
=
data1
.
batch
(
5
,
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
],
operations
=
cutmix_batch_op
)
with
pytest
.
raises
(
RuntimeError
)
as
error
:
images_cutmix
=
np
.
array
([])
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
error_message
=
"Both images and labels columns are required"
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail6
():
"""
Test CutMixBatch op
We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
"""
logger
.
info
(
"test_cutmix_batch_fail6"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
one_hot_op
=
data_trans
.
OneHot
(
num_classes
=
10
)
data1
=
data1
.
map
(
input_columns
=
[
"label"
],
operations
=
one_hot_op
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NCHW
)
data1
=
data1
.
batch
(
5
,
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
with
pytest
.
raises
(
RuntimeError
)
as
error
:
images_cutmix
=
np
.
array
([])
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
error_message
=
"CutMixBatch: Image doesn't match the given image format."
assert
error_message
in
str
(
error
.
value
)
def
test_cutmix_batch_fail7
():
"""
Test CutMixBatch op
We expect this to fail because labels are not in one-hot format
"""
logger
.
info
(
"test_cutmix_batch_fail7"
)
# CutMixBatch Images
data1
=
ds
.
Cifar10Dataset
(
DATA_DIR
,
num_samples
=
10
,
shuffle
=
False
)
cutmix_batch_op
=
vision
.
CutMixBatch
(
mode
.
ImageBatchFormat
.
NHWC
)
data1
=
data1
.
batch
(
5
,
drop_remainder
=
True
)
data1
=
data1
.
map
(
input_columns
=
[
"image"
,
"label"
],
operations
=
cutmix_batch_op
)
with
pytest
.
raises
(
RuntimeError
)
as
error
:
images_cutmix
=
np
.
array
([])
for
idx
,
(
image
,
_
)
in
enumerate
(
data1
):
if
idx
==
0
:
images_cutmix
=
image
else
:
images_cutmix
=
np
.
append
(
images_cutmix
,
image
,
axis
=
0
)
error_message
=
"CutMixBatch: Label's must be in one-hot format and in a batch"
assert
error_message
in
str
(
error
.
value
)
if
__name__
==
"__main__"
:
test_cutmix_batch_success1
(
plot
=
True
)
test_cutmix_batch_success2
(
plot
=
True
)
test_cutmix_batch_nchw_md5
()
test_cutmix_batch_nhwc_md5
()
test_cutmix_batch_fail1
()
test_cutmix_batch_fail2
()
test_cutmix_batch_fail3
()
test_cutmix_batch_fail4
()
test_cutmix_batch_fail5
()
test_cutmix_batch_fail6
()
test_cutmix_batch_fail7
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录