提交 3ad73b7d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!463 !330 dataset: add take operation

Merge pull request !463 from ms_yan/take_op_merge
......@@ -54,6 +54,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
......@@ -650,7 +651,16 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp
return Status::OK();
}
DsOpPtr DEPipeline::ParseTakeOp(const py::dict &args) const { return DsOpPtr(); }
Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<TakeOp> op;
RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>();
......
......@@ -116,7 +116,7 @@ class DEPipeline {
Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
DsOpPtr ParseTakeOp(const py::dict &args) const;
Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
......
......@@ -38,6 +38,7 @@
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"
......
......@@ -5,13 +5,13 @@ add_library(engine-datasetops OBJECT
parallel_op.cc
pipeline_op.cc
batch_op.cc
batch_op.cc
device_queue_op.cc
map_op.cc
project_op.cc
rename_op.cc
repeat_op.cc
skip_op.cc
take_op.cc
shuffle_op.cc
zip_op.cc
)
......
......@@ -88,6 +88,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
// If buffer is none or the rows of buffer is 0,
// then get a buffer from child.
if (!buf || buf->NumRows() == 0) {
if (buf && buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "common/utils.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
Status TakeOp::Builder::SanityCheck() const {
if (build_max_takes_ <= 0) {
std::string err_msg("Take count must be greater than 0.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
// The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<TakeOp>(build_max_takes_);
return Status::OK();
}
// Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {
// Call base class printer first
PipelineOp::Print(out, show_all);
// Then display our own stuff
out << "TakeOp:"
<< "\nCurrent take count: " << take_count_ << "\nMax take count: " << max_takes_;
}
// This function will be call muti times to returns the buffer, when meet required max take count or meet
// EOF buffer then this will stop.
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) {
MS_LOG(INFO) << "meet max count and push-back eoe buffer.";
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;
// Reset the count and drain
if (!last_repeat) {
take_count_ = 0;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
}
} else {
MS_LOG(INFO) << "meet max count and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
*p_buffer = std::move(eof_buffer);
take_count_ = 0;
}
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
*p_buffer = std::move(buf);
return Status::OK();
}
// Check if the last buf is next eof
if (buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
}
// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
}
return Status::OK();
}
// Function FillBuffer mainly prepare the buffer for returning
Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
int32_t buffer_size = (*buffer)->NumRows();
if (take_count_ + buffer_size < max_takes_) {
*data_buffer = std::move(*buffer);
take_count_ = take_count_ + buffer_size;
} else {
MS_LOG(INFO) << "In last buffer: Push one buffer.";
std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
while (take_count_ < max_takes_) {
TensorRow new_row;
RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
take_count_++;
new_tensor_table->push_back(new_row);
}
(*buffer)->set_tensor_table(std::move(new_tensor_table));
*data_buffer = std::move(*buffer);
}
return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToRepeatStack(shared_from_this());
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
namespace mindspore {
namespace dataset {
class TakeOp : public PipelineOp {
public:
// The nested builder class inside of the TakeOp is used to help manage all of the arguments
// for constructing it. This take op is very simple though, so this builder is really just
// provided for a consistent look and feel for creators of Dataset operators overall.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @param count - The number of takes to do
// @return This is a constructor.
explicit Builder(int32_t count);
// Default destructor
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
Status Build(std::shared_ptr<TakeOp> *);
private:
int32_t build_max_takes_;
Status SanityCheck() const;
};
// Constructor of the TakeOp.
// @note The builder class should be used to call it
// @param count - The number of takes to do
explicit TakeOp(int32_t count);
// Destructor
~TakeOp() = default;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
// @param ro - reference to the TakeOp to display
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) {
ro.Print(out, false);
return out;
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return
Status operator()() override;
// Gets a buffer from the child node. The caller is typically our parent node.
// @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
// this function will retry to pop the connector again and will get the non-EOE buffer if any.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes
Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer);
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
......@@ -36,7 +36,7 @@ from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset, check_add_column
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
......@@ -442,6 +442,33 @@ class Dataset:
"""
return SkipDataset(self, count)
@check_take
def take(self, count=-1):
"""
Takes at most given numbers of elements from the dataset.
Note:
1. If count is greater than the number of element in dataset or equal to -1,
all the element in dataset will be taken.
2. The order of using take and batch effects. If take before batch operation,
then taken given number of rows, otherwise take given number of batches.
Args:
count (int, optional): Number of elements to be taken from the dataset (default=-1).
Returns:
TakeDataset, dataset taken.
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object.
>>> # creates a dataset where the dataset including 50 elements.
>>> data = data.take(50)
"""
if count == -1:
return self
return TakeDataset(self, count)
@check_zip_dataset
def zip(self, datasets):
"""
......@@ -1100,6 +1127,7 @@ class RepeatDataset(DatasetOp):
"""
return self.count
class SkipDataset(DatasetOp):
"""
The result of applying Skip operator to the input Dataset.
......@@ -1134,6 +1162,41 @@ class SkipDataset(DatasetOp):
output_size = child_size - self.count
return output_size
class TakeDataset(DatasetOp):
"""
The result of applying Take operator to the input Dataset.
Args:
input_dataset (Dataset): Input Dataset to be taken element from.
count (int): Number of elements to be taken from the dataset.
"""
def __init__(self, input_dataset, count):
super().__init__()
self.count = count
self.input.append(input_dataset)
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
args = super().get_args()
args["count"] = self.count
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
child_size = self.input[0].get_dataset_size()
if child_size < self.count:
return child_size
return self.count
class ZipDataset(DatasetOp):
"""
The result of applying Zip operator to the input Dataset.
......
......@@ -129,6 +129,8 @@ class Iterator:
op_type = OpName.REPEAT
elif isinstance(dataset, de.SkipDataset):
op_type = OpName.SKIP
elif isinstance(dataset, de.TakeDataset):
op_type = OpName.TAKE
elif isinstance(dataset, de.StorageDataset):
op_type = OpName.STORAGE
elif isinstance(dataset, de.ImageFolderDatasetV2):
......
......@@ -304,6 +304,9 @@ def create_node(node):
elif dataset_op == 'SkipDataset':
pyobj = de.Dataset().skip(node.get('count'))
elif dataset_op == 'TakeDataset':
pyobj = de.Dataset().take(node.get('count'))
elif dataset_op == 'MapDataset':
tensor_ops = construct_tensor_ops(node.get('operations'))
pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
......
......@@ -602,7 +602,7 @@ def check_batch_size(batch_size):
def check_count(count):
check_type(count, 'count', int)
if (count <= 0 and count != -1) or count > INT32_MAX:
raise ValueError("repeat count should be either -1 or positive integer.")
raise ValueError("count should be either -1 or positive integer.")
def check_columns(columns, name):
......@@ -709,6 +709,7 @@ def check_repeat(method):
return new_method
def check_skip(method):
"""check the input arguments of skip."""
@wraps(method)
......@@ -724,6 +725,21 @@ def check_skip(method):
return new_method
def check_take(method):
"""check the input arguments of take."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
count = param_dict.get('count')
check_count(count)
return method(*args, **kwargs)
return new_method
def check_zip(method):
"""check the input arguments of zip."""
@wraps(method)
......@@ -759,6 +775,7 @@ def check_zip_dataset(method):
return new_method
def check_rename(method):
"""check the input arguments of rename."""
@wraps(method)
......
......@@ -64,6 +64,7 @@ SET(DE_UT_SRCS
voc_op_test.cc
cifar_op_test.cc
celeba_op_test.cc
take_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "common/common.h"
#include "common/utils.h"
#include "dataset/core/client.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestTakeOp : public UT::DatasetOpTesting {};
TEST_F(MindDataTestTakeOp, TestTakeProject) {
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
// TFReaderOp
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetWorkerConnectorSize(16)
.SetNumWorkers(16);
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
builder.SetDataSchema(std::move(schema));
Status rc = builder.Build(&my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
// TakeOp
std::shared_ptr<TakeOp> my_take_op;
TakeOp::Builder builder_take(5);
rc = builder_take.Build(&my_take_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_take_op);
ASSERT_TRUE(rc.IsOk());
// Set children/root layout.
rc = my_take_op->AddChild(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_take_op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = my_tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = my_tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 5);
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
import numpy as np
# In generator dataset: Number of rows is 3, its value is 0, 1, 2
def generator():
for i in range(3):
yield np.array([i]),
# In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 10
def generator_10():
for i in range(10):
yield np.array([i]),
def test_take_01():
"""
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
"""
logger.info("test_take_01")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(1)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 0 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_02():
"""
Test take: origin there are 3 row, and take 2 row, in this case: will meet eoe
"""
logger.info("test_take_02")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 2 == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_03():
"""
Test take: origin there are 3 row, and take 3 row, in this case: will meet eoe and eof
"""
logger.info("test_take_03")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(3)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_04():
"""
Test take: origin there are 3 row, and take 4 row, this is more than the total rows
"""
logger.info("test_take_04")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(4)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_05():
"""
Test take: there is no repeat op
"""
logger.info("test_take_05")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_06():
"""
Test take: repeat is before take
"""
logger.info("test_take_06")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.repeat(2)
data1 = data1.take(4)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_07():
"""
Test take: take is before batch, that mean take(N), N refer to rows num
"""
logger.info("test_take_07")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.batch(2)
assert sum([1 for _ in data1]) == 1
def test_take_08():
"""
Test take: take is after batch, that mean take(N), N refer to batches num
"""
logger.info("test_take_08")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.take(2)
assert sum([1 for _ in data1]) == 2
def test_take_09():
"""
Test take: repeat count is -1, and read the whole dataset, take after repeat
"""
logger.info("test_take_09")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.repeat(2)
data1 = data1.take(-1)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_10():
"""
Test take: repeat count is -1, and read the whole dataset, take before repeat
"""
logger.info("test_take_10")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(-1)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i % 3 == d[0][0]
assert sum([1 for _ in data1]) == 6
def test_take_11():
"""
Test take: batch first, then do repeat and take operation
"""
logger.info("test_take_11")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.repeat(2)
data1 = data1.take(-1)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 * (i % 2) == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_12():
"""
Test take: take first, then do batch and repeat operation
"""
logger.info("test_take_12")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(2)
data1 = data1.batch(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 0 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_13():
"""
Test take: skip first, then do take, batch and repeat operation
"""
logger.info("test_take_13")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.skip(2)
data1 = data1.take(-1)
data1 = data1.batch(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_14():
"""
Test take: take first, then do batch, skip and repeat operation
"""
logger.info("test_take_14")
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.take(-1)
data1 = data1.batch(2)
data1 = data1.skip(1)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 == d[0][0]
assert sum([1 for _ in data1]) == 2
def test_take_15():
"""
Test take: large amount data, take a part, then do skip operation
"""
logger.info("test_take_15")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(6)
data1 = data1.skip(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert (i + 2) == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_16():
"""
Test take: large amount data, skip a part, then do take operation
"""
logger.info("test_take_16")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.skip(3)
data1 = data1.take(5)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert (i + 3) == d[0][0]
assert sum([1 for _ in data1]) == 5
if __name__ == '__main__':
test_take_01()
test_take_02()
test_take_03()
test_take_04()
test_take_05()
test_take_06()
test_take_07()
test_take_08()
test_take_09()
test_take_10()
test_take_11()
test_take_12()
test_take_13()
test_take_14()
test_take_15()
test_take_16()
logger.info('== test take operation finished ==')
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册