Commit 2e3d55ed authored by mindspore-ci-bot, committed by Gitee

!1281 Implementation of SplitOp

Merge pull request !1281 from Peilin/splitOp
......@@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
std::string err_msg = "Error: Shuffle buffer size is missing";
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "reshuffle_each_epoch") {
(void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"]));
}
}
}
std::shared_ptr<ShuffleOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
......
......@@ -51,6 +51,7 @@
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
......@@ -425,11 +426,14 @@ void bindSamplerOps(py::module *m) {
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices", [](Sampler &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
});
.def("get_indices",
[](Sampler &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child",
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
......@@ -441,12 +445,16 @@ void bindSamplerOps(py::module *m) {
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
.def(py::init<bool>(), py::arg("replacement"));
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
py::arg("num_samples"))
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
......
......@@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT
sampler.cc
sequential_sampler.cc
subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc
)
......@@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
} else if (cnt_ == samples_per_buffer_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
while (cnt_ < samples_per_buffer_) {
int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_;
*(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id;
int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
if (shuffle_) {
sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
cnt_++;
}
TensorRow row(1, sample_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
......@@ -72,11 +86,29 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
Status DistributedSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
cnt_ = 0;
rnd_.seed(seed_++);
if (shuffle_ == true) {
rnd_.seed(seed_);
seed_++;
std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
void DistributedSampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): DistributedSampler\n";
if (show_all) {
out << "seed_: " << seed_ << '\n';
out << "device_id_: " << device_id_ << '\n';
out << "num_devices_: " << num_devices_ << '\n';
out << "shuffle_: " << shuffle_ << '\n';
}
}
} // namespace dataset
} // namespace mindspore
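
As an aside on the id selection above: each device strides through the row ids, so the output is fully determined by num_devices_, device_id_ and num_rows_. A minimal Python sketch of the formula (hypothetical standalone code, not part of this patch):

num_devices, device_id, num_rows = 2, 1, 5
num_samples = (num_rows + num_devices - 1) // num_devices  # ceil(5 / 2) = 3 ids for this shard
ids = [(num_devices * cnt + device_id) % num_rows for cnt in range(num_samples)]
# ids == [1, 3, 0], matching test_config(2, 1) in test_sampler.py further below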
......@@ -48,6 +48,8 @@ class DistributedSampler : public Sampler {
// @return - The error code returned
Status Reset() override;
void Print(std::ostream &out, bool show_all) const override;
private:
int64_t cnt_; // number of samples that have already been filled into the buffer
uint32_t seed_;
......
......@@ -38,6 +38,7 @@ Status PKSampler::InitSampler() {
rnd_.seed(seed_++);
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
num_samples_ = num_pk_samples_;
if (shuffle_ == true) {
std::shuffle(labels_.begin(), labels_.end(), rnd_);
} else {
......@@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
} else if (next_id_ == num_pk_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids;
int64_t last_id =
......@@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
int64_t cls_id = next_id_++ / samples_per_class_;
const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
*(id_ptr++) = samples[rnd_ind];
int64_t sampled_id = samples[rnd_ind];
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
}
TensorRow row(1, sample_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
......@@ -75,6 +88,11 @@ Status PKSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
......
......@@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
std::shared_ptr<Tensor> sample_ids;
{
py::gil_scoped_acquire gil_acquire;
......@@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
py::object py_ret = py_sampler_instance.attr("_get_indices")();
py::array np_sample_ids = py_ret.cast<py::array>();
Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor
if (HasChildSampler()) {
for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
int64_t associated_child_id = 0;
RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id));
*it = associated_child_id;
}
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
} catch (const py::cast_error &e) {
......@@ -79,6 +91,11 @@ Status PythonSampler::Reset() {
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
} // namespace dataset
......
......@@ -14,18 +14,22 @@
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "dataset/util/random.h"
namespace mindspore {
namespace dataset {
RandomSampler::RandomSampler(bool replacement, int64_t num_samples, int64_t samples_per_buffer)
RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
int64_t samples_per_buffer)
: Sampler(samples_per_buffer),
seed_(GetSeed()),
replacement_(replacement),
user_num_samples_(num_samples),
next_id_(0),
reshuffle_each_epoch_(reshuffle_each_epoch),
dist(nullptr) {}
Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
......@@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds;
int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_;
int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
for (int64_t i = 0; i < (last_id - next_id_); i++) {
*(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)];
int64_t sampled_id = 0;
if (replacement_) {
sampled_id = (*dist)(rnd_);
} else {
sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*(id_ptr + i) = sampled_id;
}
next_id_ = last_id;
TensorRow row(1, sampleIds);
......@@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
rnd_.seed(seed_++);
rnd_.seed(seed_);
if (replacement_ == false) {
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
......@@ -69,11 +91,33 @@ Status RandomSampler::InitSampler() {
Status RandomSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);
if (replacement_ == false) {
if (reshuffle_each_epoch_) {
seed_++;
}
rnd_.seed(seed_);
if (replacement_ == false && reshuffle_each_epoch_) {
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
void RandomSampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): RandomSampler\n";
if (show_all) {
out << "user_num_samples_: " << user_num_samples_ << '\n';
out << "num_samples_: " << num_samples_ << '\n';
out << "next_id_: " << next_id_ << '\n';
}
}
} // namespace dataset
} // namespace mindspore
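
The reshuffle_each_epoch_ handling in Reset() above boils down to whether the seed is bumped between epochs. A small sketch of the effect (plain Python, assuming the same seed-bumping scheme as the C++ code):

import random

seed = 0
ids = list(range(5))
random.Random(seed).shuffle(ids)  # epoch 1 order
# reshuffle_each_epoch=True: Reset() increments the seed and reshuffles,
# so epoch 2 sees a new permutation
seed += 1
random.Random(seed).shuffle(ids)
# reshuffle_each_epoch=False: the seed stays fixed and no reshuffle happens,
# so every subsequent epoch replays the epoch-1 order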
......@@ -30,7 +30,8 @@ class RandomSampler : public Sampler {
// @param bool replacement - whether to put the id back after it is sampled (sample with replacement)
// @param int64_t numSamples - number of samples to draw
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(),
explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
int64_t num_samples = std::numeric_limits<int64_t>::max(),
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
......@@ -49,6 +50,8 @@ class RandomSampler : public Sampler {
// @return - The error code returned
Status Reset() override;
virtual void Print(std::ostream &out, bool show_all) const;
private:
uint32_t seed_;
bool replacement_;
......@@ -57,6 +60,7 @@ class RandomSampler : public Sampler {
int64_t next_id_;
std::mt19937 rnd_;
std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
bool reshuffle_each_epoch_;
};
} // namespace dataset
} // namespace mindspore
......
......@@ -15,18 +15,41 @@
*/
#include "dataset/engine/datasetops/source/sampler/sampler.h"
#include <string>
namespace mindspore {
namespace dataset {
Sampler::Sampler(int64_t samples_per_buffer)
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
if (HasChildSampler()) {
child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
if (!child_sampler) {
std::string err_msg("Cannot handshake, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Handshake and init child first.
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
if (HasChildSampler()) {
int64_t child_num_samples = child_sampler->num_samples();
num_rows_ = child_num_samples;
} else {
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
}
// It's up to the derived class to check the validity of the two args
// because some samplers only need one of them (e.g. weighted_random_sampler)
RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
return Status::OK();
}
......@@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
return Status::OK();
}
void Sampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): base\n";
if (show_all) {
out << "num_rows_: " << num_rows_ << '\n';
out << "num_samples_: " << num_samples_ << '\n';
}
}
Status Sampler::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids;
......@@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
num_rows_ = num_rows;
return Status::OK();
}
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
if (child == nullptr) {
return Status::OK();
}
// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
if (!sampler) {
std::string err_msg("Cannot add child, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Samplers can have at most 1 child.
if (!child_.empty()) {
std::string err_msg("Cannot add child sampler, this sampler already has a child.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
child_.push_back(child);
// Note: child->AddParent(this) cannot be called here because AddParent() is protected.
return Status::OK();
}
bool Sampler::HasChildSampler() { return !child_.empty(); }
Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
if (child_ids_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
}
TensorRow sample_row;
RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
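
GetAssociatedChildId() is the heart of sampler chaining: the parent's sampled id is used as an index into the buffer of ids produced by the child. A worked example (plain Python, for illustration only):

child_ids = [4, 2, 0, 3, 1]  # ids emitted by the child sampler for this epoch
parent_ids = [0, 2]          # ids the parent sampler would emit on its own
final_ids = [child_ids[i] for i in parent_ids]
# final_ids == [4, 0]: the parent picks positions, the child supplies the rows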
......@@ -90,6 +90,8 @@ class Sampler : public DatasetOp {
// setter function for num_samples_
Status SetNumSamples(int64_t num_samples);
int64_t num_samples() { return num_samples_; }
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// @return
......@@ -114,17 +116,48 @@ class Sampler : public DatasetOp {
// @return - The error code returned
Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
Status AddChild(std::shared_ptr<DatasetOp> child);
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds
// @param int64_t numElements - must be a non-zero number
// @return
// @return - The error code returned.
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
void Print(std::ostream &out, bool show_all) const override;
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
sampler.Print(out, false);
return out;
}
// Checks if this sampler has a child sampler.
// @return - true if there is a child sampler, false otherwise.
bool HasChildSampler();
// Uses id as an index for the list of ids generated by the child sampler, and gets the
// associated id.
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
// @param int64_t id - The id used as an index to get the associated child id.
// @return - The error code returned.
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
protected:
// Number of rows of data from the place this sampler is sampling from. If this sampler
// has a child sampler, num_rows_ is the number of ids the child sampler will
// output. Otherwise, num_rows_ is the number of rows in the dataset.
int64_t num_rows_;
// Number of ids this sampler will return.
int64_t num_samples_;
// The max number of ids a DataBuffer returned by this sampler will contain.
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::unique_ptr<DataBuffer> child_ids_;
};
} // namespace dataset
} // namespace mindspore
......
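
On the Python side this chaining is driven by add_child(); for example, a DistributedSampler can shard the output of a SequentialSampler, which is exactly what test_sampler_chain further below exercises (dataset path is a placeholder):

import mindspore.dataset as ds

sampler = ds.DistributedSampler(2, 0, False)   # 2 shards, shard id 0, no shuffle
sampler.add_child(ds.SequentialSampler())      # child feeds ids to the parent
data = ds.ManifestDataset("path/to/manifest.json", sampler=sampler)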
......@@ -15,6 +15,7 @@
*/
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include <algorithm>
#include <memory>
namespace mindspore {
......@@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds;
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
while (next_id_ < lastId) {
*(idPtr++) = next_id_++;
int64_t sampled_id = next_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*idPtr = sampled_id;
next_id_++;
idPtr++;
}
TensorRow row(1, sampleIds);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
......@@ -43,6 +56,10 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
Status SequentialSampler::InitSampler() {
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples_ <= 0, fall back to num_rows_
if (HasChildSampler()) {
num_samples_ = std::min(num_samples_, num_rows_);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
......@@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() {
Status SequentialSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }
} // namespace dataset
} // namespace mindspore
......@@ -45,6 +45,8 @@ class SequentialSampler : public Sampler {
// @return - The error code returned
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;
private:
int64_t next_id_;
};
......
......@@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
Status SubsetRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
num_samples_ = indices_.size();
// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());
......@@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() {
rand_gen_.seed(GetSeed());
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
......@@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
if (sample_id_ == indices_.size()) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> outputIds;
......@@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
RETURN_STATUS_UNEXPECTED(err_msg);
}
*(id_ptr++) = indices_[sample_id_++];
int64_t sampled_id = indices_[sample_id_];
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
sample_id_++;
}
// Create a TensorTable from that single tensor and push into DataBuffer
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include <memory>
#include <string>
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
namespace mindspore {
namespace dataset {
// Constructor.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
num_samples_ = subset_size_;
return Status::OK();
}
Status SubsetSampler::Reset() {
current_id_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (current_id_ > subset_size_) {
RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
} else if (current_id_ == subset_size_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampled_ids;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));
int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());
while (current_id_ < subset_size_) {
int64_t sampled_id = start_index_ + current_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*(sampled_ids_start_addr + current_id_) = sampled_id;
current_id_++;
}
TensorRow sampled_ids_row(1, sampled_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
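
The id generation in GetNextBuffer() above is simply start_index_ + current_id_ for subset_size_ steps, sketched here in Python:

start_index, subset_size = 100, 5
ids = [start_index + i for i in range(subset_size)]
# ids == [100, 101, 102, 103, 104]; with a child sampler attached, each id is
# additionally mapped through the child's output via GetAssociatedChildId()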
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#include <memory>
#include <vector>
#include "dataset/engine/datasetops/source/sampler/sampler.h"
namespace mindspore {
namespace dataset {
class SubsetSampler : public Sampler {
public:
// Constructor.
// @param start_index The index we start sampling from.
explicit SubsetSampler(int64_t start_index, int64_t subset_size);
// Destructor.
~SubsetSampler() = default;
// Initialize the sampler.
// @return Status
Status InitSampler() override;
// Reset the internal variable to the initial state and reshuffle the indices.
// @return Status
Status Reset() override;
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
int64_t start_index_;
int64_t subset_size_;
int64_t current_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
......@@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights,
Status WeightedRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_ > 0, "num_samples & num_rows need to be positive");
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
num_samples_ = user_num_samples_;
// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());
......@@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() {
} else {
discrete_dist_->reset();
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
......@@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
if (sample_id_ == user_num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> outputIds;
......@@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
}
*(id_ptr++) = genId;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
}
*id_ptr = genId;
id_ptr++;
sample_id_++;
}
......
......@@ -47,6 +47,7 @@ class Sampler:
def __init__(self):
self.dataset_size = 0
self.num_samples = 0
self.child_sampler = None
def __iter__(self):
"""
......@@ -83,7 +84,35 @@ class Sampler:
# Instance fetcher
# Do not override this method!
def create(self):
return cde.PythonSampler(self)
c_sampler = cde.PythonSampler(self)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def add_child(self, sampler):
self.child_sampler = sampler
def get_child(self):
return self.child_sampler
def create_child(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create()
return c_child_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
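
For reference, a user-defined sampler only needs to implement __iter__; create() then wires it, and any child attached via add_child(), into the C++ pipeline. A minimal sketch assuming the documented Sampler interface (EvenIndexSampler is a hypothetical name):

import mindspore.dataset as ds

class EvenIndexSampler(ds.Sampler):
    """Yield every second index of the dataset."""
    def __iter__(self):
        for i in range(0, self.dataset_size, 2):
            yield i

sampler = EvenIndexSampler()
sampler.add_child(ds.SequentialSampler())  # optional chaining, same as the built-ins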
class BuiltinSampler:
......@@ -93,11 +122,30 @@ class BuiltinSampler:
User should not extend this class.
"""
def __init__(self):
pass
self.child_sampler = None
def create(self):
pass
def add_child(self, sampler):
self.child_sampler = sampler
def get_child(self):
return self.child_sampler
def create_child(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create()
return c_child_sampler
def is_shuffled(self):
raise NotImplementedError("Sampler must implement is_shuffled.")
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")
class DistributedSampler(BuiltinSampler):
"""
......@@ -142,7 +190,22 @@ class DistributedSampler(BuiltinSampler):
def create(self):
# each time the user calls create_dict_iterator() (e.g. to repeat the dataset), the sampler gets a different seed to shuffle with
self.seed += 1
return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return self.num_shards > 1
return self.child_sampler.is_sharded()
class PKSampler(BuiltinSampler):
......@@ -186,7 +249,22 @@ class PKSampler(BuiltinSampler):
super().__init__()
def create(self):
return cde.PKSampler(self.num_val, self.shuffle)
c_sampler = cde.PKSampler(self.num_val, self.shuffle)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def _create_for_minddataset(self):
if not self.class_column or not isinstance(self.class_column, str):
......@@ -226,15 +304,31 @@ class RandomSampler(BuiltinSampler):
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
self.deterministic = False
self.replacement = replacement
self.num_samples = num_samples
self.reshuffle_each_epoch = True
super().__init__()
def create(self):
# If num_samples is not specified, then call constructor #2
c_sampler = None
if self.num_samples is None:
return cde.RandomSampler(self.replacement)
return cde.RandomSampler(self.replacement, self.num_samples)
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
else:
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SequentialSampler(BuiltinSampler):
......@@ -252,7 +346,80 @@ class SequentialSampler(BuiltinSampler):
"""
def create(self):
return cde.SequentialSampler()
c_sampler = cde.SequentialSampler()
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SubsetSampler(BuiltinSampler):
"""
Samples a subset of elements consecutively from a given index.
Args:
start_index (int): Index to start sampling at.
subset_size (int): How many samples to include in this subset.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> dataset_dir = "path/to/imagefolder_directory"
>>>
>>> # creates a SubsetSampler that samples the 5 images starting at index 100
>>> sampler = ds.SubsetSampler(100, 5)
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
Raises:
ValueError: If start_index is not an int or is negative.
ValueError: If subset_size is not an int or is negative.
"""
def __init__(self, start_index, subset_size):
if not isinstance(start_index, int):
raise ValueError("start_index should be an int.")
if start_index < 0:
raise ValueError("start_index should not be negative.")
if not isinstance(subset_size, int):
raise ValueError("subset_size should be an int.")
if subset_size < 0:
raise ValueError("subset_size should not be negative.")
self.start_index = start_index
self.subset_size = subset_size
super().__init__()
def create(self):
c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SubsetRandomSampler(BuiltinSampler):
......@@ -282,7 +449,19 @@ class SubsetRandomSampler(BuiltinSampler):
super().__init__()
def create(self):
return cde.SubsetRandomSampler(self.indices)
c_sampler = cde.SubsetRandomSampler(self.indices)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def _create_for_minddataset(self):
return cde.MindrecordSubsetRandomSampler(self.indices)
......@@ -330,4 +509,16 @@ class WeightedRandomSampler(BuiltinSampler):
super().__init__()
def create(self):
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
......@@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
return method(*args, **kwargs)
return new_method
def check_split(method):
"""check the input arguments of split."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_list = ['sizes']
nreq_param_bool = ['randomize']
check_param_type(nreq_param_list, param_dict, list)
check_param_type(nreq_param_bool, param_dict, bool)
# check sizes: must be list of float or list of int
sizes = param_dict.get('sizes')
if not sizes:
raise ValueError("sizes cannot be empty.")
all_int = all(isinstance(item, int) for item in sizes)
all_float = all(isinstance(item, float) for item in sizes)
if not (all_int or all_float):
raise ValueError("sizes should be list of int or list of float.")
if all_int:
all_positive = all(item > 0 for item in sizes)
if not all_positive:
raise ValueError("sizes is a list of int, but there should be no negative numbers.")
if all_float:
all_valid_percentages = all(0 < item <= 1 for item in sizes)
if not all_valid_percentages:
raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")
epsilon = 0.00001
if not abs(sum(sizes) - 1) < epsilon:
raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
return method(*args, **kwargs)
return new_method
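
In practice the sizes argument validated here takes one of two forms, absolute row counts or float percentages that must sum to 1 (both exercised by the tests below); a usage sketch with a placeholder path:

import mindspore.dataset as ds

d = ds.TextFileDataset("path/to/text_files/*", shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)      # absolute row counts
s1, s2 = d.split([0.8, 0.2], randomize=False)  # percentages summing to 1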
......@@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
uint32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
tree->Prepare();
......
......@@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
int32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
......
......@@ -164,9 +164,36 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
label_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_shards, shard_id):
sampler = ds.DistributedSampler(num_shards, shard_id, False)
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
res = []
for item in data1.create_dict_iterator():
logger.info("item[image].shape[0]: {}, item[label].item(): {}"
.format(item["image"].shape[0], item["label"].item()))
res.append(label_map[(item["image"].shape[0], item["label"].item())])
return res
assert test_config(2, 0) == [0, 2, 4]
assert test_config(2, 1) == [1, 3, 0]
assert test_config(5, 0) == [0]
assert test_config(5, 1) == [1]
assert test_config(5, 2) == [2]
assert test_config(5, 3) == [3]
assert test_config(5, 4) == [4]
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()
test_sampler_chain()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
# test5trainimgs.json contains 5 images; each one can be uniquely identified by its
# (un-decoded size, label) pair, so the manifest_map lookup table below maps each pair
# back to the image's original index in the dataset
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def split_with_invalid_inputs(d):
with pytest.raises(ValueError) as info:
s1, s2 = d.split([])
assert "sizes cannot be empty" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([5, 0.6])
assert "sizes should be list of int or list of float" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([-1, 6])
assert "there should be no negative numbers" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([3, 1])
assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([5, 1])
assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([-0.5, 0.5])
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([1.5, 0.5])
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([0.5, 0.6])
assert "percentages do not sum up to 1" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([0.3, 0.6])
assert "percentages do not sum up to 1" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([0.05, 0.95])
assert "percentage 0.05 is too small" in str(info.value)
def test_unmappable_invalid_input():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
d = ds.TextFileDataset(text_file_dataset_path)
split_with_invalid_inputs(d)
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
ds.config.set_num_parallel_workers(4)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s1_output == text_file_data[0:4]
assert s2_output == text_file_data[4:]
# exact percentages
s1, s2 = d.split([0.8, 0.2], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s1_output == text_file_data[0:4]
assert s2_output == text_file_data[4:]
# fuzzy percentages
s1, s2 = d.split([0.33, 0.67], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]
def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file)
split_with_invalid_inputs(d)
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_mappable_split_general():
d = ds.ManifestDataset(manifest_file, shuffle=False)
d = d.take(5)
# absolute rows
s1, s2 = d.split([4, 1], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# exact percentages
s1, s2 = d.split([0.8, 0.2], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# fuzzy percentages
s1, s2 = d.split([0.33, 0.67], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_split_optimized():
d = ds.ManifestDataset(manifest_file, shuffle=False)
# absolute rows
s1, s2 = d.split([4, 1], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# exact percentages
s1, s2 = d.split([0.8, 0.2], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# fuzzy percentages
s1, s2 = d.split([0.33, 0.67], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_randomize_deterministic():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
for _ in range(10):
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
# note no overlap
assert s1_output == [0, 1, 3, 4]
assert s2_output == [2]
def test_mappable_randomize_repeatable():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
num_epochs = 5
s1 = s1.repeat(num_epochs)
s2 = s2.repeat(num_epochs)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
# note no overlap
assert s1_output == [0, 1, 3, 4] * num_epochs
assert s2_output == [2] * num_epochs
def test_mappable_sharding():
# set arbitrary seed for repeatability for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
ds.config.set_seed(53)
num_epochs = 5
first_split_num_rows = 4
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([first_split_num_rows, 1])
distributed_sampler = ds.DistributedSampler(2, 0)
s1.use_sampler(distributed_sampler)
s1 = s1.repeat(num_epochs)
# testing sharding, second dataset to simulate another instance
d2 = ds.ManifestDataset(manifest_file, shuffle=False)
d2s1, d2s2 = d2.split([first_split_num_rows, 1])
distributed_sampler = ds.DistributedSampler(2, 1)
d2s1.use_sampler(distributed_sampler)
d2s1 = d2s1.repeat(num_epochs)
# shard 0
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
# shard 1
d2s1_output = []
for item in d2s1.create_dict_iterator():
d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
rows_per_shard_per_epoch = 2
assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs
# verify each epoch that
# 1. shards contain no common elements
# 2. the data was split the same way, and that the union of shards equal the split
correct_sorted_split_result = [0, 1, 3, 4]
for i in range(num_epochs):
combined_data = []
for j in range(rows_per_shard_per_epoch):
combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])
assert sorted(combined_data) == correct_sorted_split_result
# test other split
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
d2s2_output = []
for item in d2s2.create_dict_iterator():
d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]
assert d2s2_output == [2]
if __name__ == '__main__':
test_unmappable_invalid_input()
test_unmappable_split()
test_mappable_invalid_input()
test_mappable_split_general()
test_mappable_split_optimized()
test_mappable_randomize_deterministic()
test_mappable_randomize_repeatable()
test_mappable_sharding()