提交 40f0a4a4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!333 Add skip op to Dataset

Merge pull request !333 from jiangzhiwen/dataset/skip
......@@ -47,6 +47,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kMap, &DEPipeline::ParseMapOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
......@@ -511,6 +512,17 @@ Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp
return Status::OK();
}
Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<SkipOp> op;
RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
for (auto arg : args) {
......
......@@ -42,6 +42,7 @@ enum OpName {
kBatch,
kCache,
kRepeat,
kSkip,
kTake,
kZip,
kMap,
......@@ -107,6 +108,8 @@ class DEPipeline {
Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
......
......@@ -446,6 +446,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
.value("SKIP", OpName::kSkip)
.value("TAKE", OpName::kTake)
.value("ZIP", OpName::kZip)
.value("MAP", OpName::kMap)
......
......@@ -32,6 +32,7 @@
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
......
......@@ -11,6 +11,7 @@ add_library(engine-datasetops OBJECT
project_op.cc
rename_op.cc
repeat_op.cc
skip_op.cc
shuffle_op.cc
zip_op.cc
)
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <utility>
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
Status SkipOp::Builder::SanityCheck() const {
if (build_max_skips_ < 0) {
std::string err_msg("Skip count must be positive integer or 0.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
// The builder "build" method creates the final object.
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<SkipOp>(build_max_skips_);
return Status::OK();
}
// Constructor of the SkipOp.
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
// Destructor
SkipOp::~SkipOp() {}
// A print method typically used for debugging
void SkipOp::Print(std::ostream &out, bool show_all) const {
// Call base class printer first
PipelineOp::Print(out, show_all);
// Then display our own stuff
out << "SkipOp:"
<< "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
}
// Since the buffer may contain multi rows, this function will drop the rows
// that need to skip in it, and then return the buffer.
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
// Drop first max_skips_ rows
while (skip_count_ < max_skips_) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
if (buf->eoe() || buf->eof()) {
break;
}
// Consider the rows of buffer more than 1
TensorRow drop_row;
int row_num = buf->NumRows();
for (int i = 0; i < row_num; i++) {
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
if (++skip_count_ == max_skips_) {
break;
}
}
}
// If buffer is none or the rows of buffer is 0,
// then get a buffer from child.
if (!buf || buf->NumRows() == 0) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
// Handling eoe and eof
if (buf->eoe() || buf->eof()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
if (state_ == OpState::kDeOpIdle) {
*p_buffer = std::move(buf);
return Status::OK();
}
}
*p_buffer = std::move(buf);
return Status::OK();
}
// Base-class override for handling cases when an eoe is received.
Status SkipOp::EoeReceived(int32_t worker_id) {
skip_count_ = 0;
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as a inlined operator, so it is invalid to
// launch the functor since this op runs inlined inside another operator. The
// function is overloaded to ensure that it is not called by mistake (it will
// generate an error).
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
// Base-class override for handling cases when an eof is received.
Status SkipOp::EofReceived(int32_t worker_id) {
MS_LOG(INFO) << "Skip operator EOF received, do nothing now.";
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
namespace mindspore {
namespace dataset {
class SkipOp : public PipelineOp {
public:
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @param count - The number of skip to do
// @return This is a constructor.
explicit Builder(int32_t count);
// Default destructor
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
Status Build(std::shared_ptr<SkipOp> *);
private:
int32_t build_max_skips_;
Status SanityCheck() const;
};
// Constructor of the SkipOp.
// @note The builder class should be used to call it
// @param count - The number of skips to do
explicit SkipOp(int32_t count);
// Destructor
~SkipOp();
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return
Status operator()() override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
// a buffer from our child.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// Base-class override for handling cases when an eoe is received.
// @param worker_id - The worker id
Status EoeReceived(int32_t worker_id) override;
// Base-class override for handling cases when an eof is received.
// @param worker_id - The worker id
Status EofReceived(int32_t worker_id) override;
private:
int32_t max_skips_; // The number of skips that the user requested
int32_t skip_count_; // A counter for the current number of executed skips
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
......@@ -35,7 +35,7 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_zip, check_rename, \
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset, check_add_column
......@@ -423,6 +423,25 @@ class Dataset:
return self
return RepeatDataset(self, count)
@check_skip
def skip(self, count):
"""
Skip the first N elements of this dataset.
Args:
count (int): Number of elements the dataset should be skipped.
Returns:
SkipDataset, dataset skipped.
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object.
>>> # creates a dataset which skips first 3 elements from data
>>> data = data.skip(3)
"""
return SkipDataset(self, count)
@check_zip_dataset
def zip(self, datasets):
"""
......@@ -1081,6 +1100,39 @@ class RepeatDataset(DatasetOp):
"""
return self.count
class SkipDataset(DatasetOp):
"""
The result of applying Skip operator to the input Dataset.
Args:
datasets (tuple): A tuple of datasets to be skipped.
count (int): Number of rows the dataset should be skipped.
"""
def __init__(self, input_dataset, count):
super().__init__()
self.count = count
self.input.append(input_dataset)
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
args = super().get_args()
args["count"] = self.count
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
child_size = self.input[0].get_dataset_size()
output_size = 0
if self.count >= 0 and self.count < child_size:
output_size = child_size - self.count
return output_size
class ZipDataset(DatasetOp):
"""
......
......@@ -127,6 +127,8 @@ class Iterator:
op_type = OpName.MAP
elif isinstance(dataset, de.RepeatDataset):
op_type = OpName.REPEAT
elif isinstance(dataset, de.SkipDataset):
op_type = OpName.SKIP
elif isinstance(dataset, de.StorageDataset):
op_type = OpName.STORAGE
elif isinstance(dataset, de.ImageFolderDatasetV2):
......
......@@ -297,6 +297,9 @@ def create_node(node):
elif dataset_op == 'RepeatDataset':
pyobj = de.Dataset().repeat(node.get('count'))
elif dataset_op == 'SkipDataset':
pyobj = de.Dataset().skip(node.get('count'))
elif dataset_op == 'MapDataset':
tensor_ops = construct_tensor_ops(node.get('operations'))
pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
......
......@@ -709,6 +709,20 @@ def check_repeat(method):
return new_method
def check_skip(method):
"""check the input arguments of skip."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
count = param_dict.get('count')
check_type(count, 'count', int)
if count < 0:
raise ValueError("Skip count must be positive integer or 0.")
return method(*args, **kwargs)
return new_method
def check_zip(method):
"""check the input arguments of zip."""
......
......@@ -41,6 +41,7 @@ SET(DE_UT_SRCS
random_vertical_flip_op_test.cc
rename_op_test.cc
repeat_op_test.cc
skip_op_test.cc
rescale_op_test.cc
resize_bilinear_op_test.cc
resize_op_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/circular_pool.h"
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestSkipOp : public UT::DatasetOpTesting {};
TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetWorkerConnectorSize(16)
.SetNumWorkers(16);
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
builder.SetDataSchema(std::move(schema));
Status rc = builder.Build(&my_tfreader_op); ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
// SkipOp
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
rc = my_tree->AssociateNode(skip_op);
ASSERT_TRUE(rc.IsOk());
// Set children/root layout.
rc = skip_op->AddChild(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(skip_op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = my_tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = my_tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 7);
}
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_tf_skip():
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
resize_height, resize_width = 32, 32
decode_op = vision.Decode()
resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=resize_op)
data1 = data1.skip(2)
num_iter = 0
for item in data1.create_dict_iterator():
num_iter += 1
assert num_iter == 1
def generator_md():
# Create a dataset with [0, 1, 2, 3, 4]
for i in range(5):
yield (np.array([i]), )
def test_generator_skip():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [3, 4]
ds1 = ds1.skip(3)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 2
def test_skip_1():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be []
ds1 = ds1.skip(7)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 0
def test_skip_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [0, 1, 2, 3, 4]
ds1 = ds1.skip(0)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 5
def test_skip_repeat_1():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
ds1 = ds1.repeat(2)
# Here ds1 should be [3, 4, 0, 1, 2, 3, 4]
ds1 = ds1.skip(3)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 7
def test_skip_repeat_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [3, 4]
ds1 = ds1.skip(3)
# Here ds1 should be [3, 4, 3, 4]
ds1 = ds1.repeat(2)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 4
def test_skip_repeat_3():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
ds1 = ds1.repeat(2)
# Here ds1 should be [3, 4]
ds1 = ds1.skip(8)
# Here ds1 should be [3, 4, 3, 4, 3, 4]
ds1 = ds1.repeat(3)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 6
if __name__ == "__main__":
test_tf_skip()
test_generator_skip()
test_skip_1()
test_skip_2()
test_skip_repeat_1()
test_skip_repeat_2()
test_skip_repeat_3()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册