提交 3a1d8b5f 编写于 作者: E ervinzhang

added channel swap operation in transforms.h

上级 9ad82f79
......@@ -21,6 +21,7 @@
#include "minddata/dataset/kernels/image/crop_op.h"
#include "minddata/dataset/kernels/image/cut_out_op.h"
#include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/data/one_hot_op.h"
......@@ -83,6 +84,16 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb) {
return op;
}
// Function to create HwcToChwOperation.
std::shared_ptr<HwcToChwOperation> HWC2CHW() {
auto op = std::make_shared<HwcToChwOperation>();
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create MixUpBatchOperation.
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
auto op = std::make_shared<MixUpBatchOperation>(alpha);
......@@ -293,6 +304,11 @@ bool DecodeOperation::ValidateParams() { return true; }
std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
// HwcToChwOperation
bool HwcToChwOperation::ValidateParams() { return true; }
std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<HwcToChwOp>(); }
// MixUpOperation
MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
......
......@@ -51,6 +51,7 @@ class CenterCropOperation;
class CropOperation;
class CutOutOperation;
class DecodeOperation;
class HwcToChwOperation;
class MixUpBatchOperation;
class NormalizeOperation;
class OneHotOperation;
......@@ -92,6 +93,11 @@ std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches = 1)
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
/// \brief Function to create a HwcToChw TensorOperation.
/// \notes Transpose the input image; shape (H, W, C) to shape (C, H, W).
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<HwcToChwOperation> HWC2CHW();
/// \brief Function to create a MixUpBatch TensorOperation.
/// \notes Apply MixUp transformation on an input batch of images and labels. The labels must be in one-hot format and
/// Batch must be called before calling this function.
......@@ -273,6 +279,15 @@ class DecodeOperation : public TensorOperation {
bool rgb_;
};
class HwcToChwOperation : public TensorOperation {
public:
~HwcToChwOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
};
class MixUpBatchOperation : public TensorOperation {
public:
explicit MixUpBatchOperation(float alpha = 1);
......
......@@ -463,6 +463,55 @@ TEST_F(MindDataTestPipeline, TestDecode) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestHwcToChw) {
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> channel_swap = vision::HWC2CHW();
EXPECT_NE(channel_swap, nullptr);
// Create a Map operation on ds
ds = ds->Map({channel_swap});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
// check if the image is in NCHW
EXPECT_EQ(batch_size == image->shape()[0] && 3 == image->shape()[1]
&& 2268 == image->shape()[2] && 4032 == image->shape()[3], true);
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册