Commit 30de261c authored by mindspore-ci-bot, committed by Gitee

!243 Support nested repeat

Merge pull request !243 from h.farahat/nested_repeat
......
@@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
   return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
 }

-// During tree prepare phase, operators may have specific operations to perform depending on
+// During tree prepare phase, operators may have specific pre-operations to perform depending on
 // their role.
-Status DatasetOp::PrepareNodeAction() {
-  // If this op does not have any children and it is in a repeat path of the tree...
-  if (child_.size() == 0 && BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) {
-    // Then, flag this operator as a leaf node in a repeat path of tree execution.
-    BitSet(&op_ctrl_flags_, kDeOpRepeated);
-    // Secondly, push ourselves onto the tree repeat stack. Later, the repeat operator
+Status DatasetOp::PrepareNodePreAction() {
+  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
+  return Status::OK();
+}
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status DatasetOp::PrepareNodePostAction() {
+  if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
+    // Push ourselves onto the tree repeat stack. Later, the repeat operator
     // above us will consume them.
     tree_->AddToRepeatStack(shared_from_this());
   }
......
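Taken together, the two new hooks split the old PrepareNodeAction into a descent step and an unwind step: the pre-action propagates the repeat flag down to every operator under a RepeatOp, and the post-action lets flagged leaves register themselves for collection. A rough Python sketch of that contract (stand-in names and flag values, not MindSpore code):

PREP_REPEAT = 0x01  # tree-level flag: a RepeatOp sits somewhere above
OP_REPEATED = 0x01  # per-op flag: this op is on a repeat path

class Op:
    def __init__(self, children=None):
        self.children = children or []
        self.ctrl_flags = 0

def prepare_node_pre_action(op, tree_flags):
    # Mirrors DatasetOp::PrepareNodePreAction: mark ops while descending.
    if tree_flags & PREP_REPEAT:
        op.ctrl_flags |= OP_REPEATED

def prepare_node_post_action(op, repeat_stack):
    # Mirrors DatasetOp::PrepareNodePostAction: a leaf on a repeat path
    # registers itself so the RepeatOp above can collect it on the way up.
    if not op.children and op.ctrl_flags & OP_REPEATED:
        repeat_stack.append(op)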
......
@@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
     return Status::OK();
   }

-  // During tree prepare phase, operators may have specific operations to perform depending on
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call its superclass version first
   // before providing their own implementations.
-  virtual Status PrepareNodeAction();
+  virtual Status PrepareNodePreAction();
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  virtual Status PrepareNodePostAction();

   // Getter function
   // @return The operator id
......
......
@@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp {
     return out;
   }

-  // During tree prepare phase, operators may have specific operations to perform depending on
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call its superclass version first
   // before providing their own implementations.
   // @return Status - The error return code
-  Status PrepareNodeAction() override {
+  Status PrepareNodePreAction() override {
     // Run common code from super class before adding ParallelOp specific logic
-    return (DatasetOp::PrepareNodeAction());
+    return (DatasetOp::PrepareNodePreAction());
   }

+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  // @return Status - The error return code
+  Status PrepareNodePostAction() override {
+    // Run common code from super class before adding ParallelOp specific logic
+    return (DatasetOp::PrepareNodePostAction());
+  }

   // Override base class reset to provide reset actions specific to the ParallelOp class.
......
......
@@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp {
   // @return The number of threads that push data to the output connector
   int32_t num_producers() const override { return 1; }

-  // During tree prepare phase, operators may have specific operations to perform depending on
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call its superclass version first
   // before providing their own implementations.
-  Status PrepareNodeAction() override {
+  Status PrepareNodePreAction() override {
     // Run common code from super class before adding PipelineOp specific logic
-    return (DatasetOp::PrepareNodeAction());
+    return (DatasetOp::PrepareNodePreAction());
   }

+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override {
+    // Run common code from super class before adding PipelineOp specific logic
+    return (DatasetOp::PrepareNodePostAction());
+  }

  protected:
......
......
@@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
   out << "RepeatOp:"
       << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
       << "\nLeaf Nodes in my execution path:";
-  if (!leaf_ops_.empty()) {
+  if (!eoe_ops_.empty()) {
     out << "\n";
-    for (size_t i = 0; i < leaf_ops_.size(); i++) {
-      out << " Operator: " << leaf_ops_[i]->id() << "\n";
+    for (size_t i = 0; i < eoe_ops_.size(); i++) {
+      out << " Operator: " << eoe_ops_[i]->id() << "\n";
     }
   } else {
     out << " kNone.";
......
@@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
 // Base-class override for executing specific RepeatOp configurations. This code will be called
 // during the execution tree prepare phase when it is visiting this operator.
-Status RepeatOp::PrepareNodeAction() {
+Status RepeatOp::PrepareNodePostAction() {
   // Run any common code from super class first before adding our own specific logic
-  RETURN_IF_NOT_OK(PipelineOp::PrepareNodeAction());
+  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
   std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
   while (leaf_op != nullptr) {
     // Track the leaf operators that are under this repeat op.
-    leaf_ops_.push_back(leaf_op);
-    // Special case. If the repeat count is 1, then pre-flag the leaf nodes
-    // to tell them they are already at their last op:
-    if (max_repeats_ == 1) {
-      leaf_op->set_control_flag(kDeOpLastRepeat);
-    }
+    eoe_ops_.push_back(leaf_op);
     leaf_op = tree_->PopFromRepeatStack();
   }
+  // Push ourselves onto the stack in case one of our ancestors is a repeat too.
+  tree_->AddToRepeatStack(shared_from_this());
   return Status::OK();
 }
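The new push at the end of PrepareNodePostAction is what makes nesting work: an inner RepeatOp drains the stack into its own eoe_ops_ and then re-registers itself, so the next repeat up the tree adopts it as one of its EOE ops. A toy trace for a leaf under repeat(2).repeat(3), with strings standing in for the operator pointers:

repeat_stack = []

# Post-action at the leaf (flagged kDeOpRepeated on the way down):
repeat_stack.append("leaf")

# Post-action at the inner RepeatOp(2): drain the stack into eoe_ops_,
# then push itself for whatever repeat sits above.
inner_eoe_ops = [repeat_stack.pop()]   # ["leaf"]
repeat_stack.append("RepeatOp(2)")

# Post-action at the outer RepeatOp(3): it adopts the inner repeat.
outer_eoe_ops = [repeat_stack.pop()]   # ["RepeatOp(2)"]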
......
@@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
 Status RepeatOp::EoeReceived(int32_t worker_id) {
   repeat_count_++;
   MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
-  // If we've reached the requested repeat count, then flag the leaf nodes
+  bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
+  bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
+  // If we've reached the requested repeat count, then flag the eoe nodes
   // to tell them they've got one more epoch to perform. When they reach the end
-  // of the last epoch, they quit rather than loop again.
-  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
-    for (size_t i = 0; i < leaf_ops_.size(); i++) {
-      leaf_ops_[i]->set_control_flag(kDeOpLastRepeat);
+  // of the last epoch, they quit rather than loop again. This happens in two cases:
+  // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
+  // 2- We are not repeated.
+  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
+    for (auto &eoe_op : eoe_ops_) {
+      eoe_op->set_control_flag(kDeOpLastRepeat);
     }
   }
   if (repeat_count_ == max_repeats_) {
     repeat_count_ = 0;
     state_ = OpState::kDeOpIdle;
     return Status::OK();
   }
......
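To see the widened condition at work, take repeat(2) nested under repeat(3): the inner op must not tell its EOE ops "one more epoch" until the outer repeat is itself on its last repetition. A stand-alone model of just that predicate (kInfiniteRepeat is assumed to be a sentinel, modeled here as -1):

K_INFINITE_REPEAT = -1  # assumed sentinel for "repeat forever"

def should_flag_last_repeat(repeat_count, max_repeats, repeated, last_repeat):
    # Mirrors: max_repeats_ != kInfiniteRepeat &&
    #          repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)
    return (max_repeats != K_INFINITE_REPEAT
            and repeat_count == max_repeats - 1
            and (not repeated or last_repeat))

# Inner RepeatOp(2) is one pass from done, but the outer repeat loops again:
assert not should_flag_last_repeat(1, 2, repeated=True, last_repeat=False)
# Same point during the outer repeat's final repetition:
assert should_flag_last_repeat(1, 2, repeated=True, last_repeat=True)
# A repeat with no repeat above it behaves as before:
assert should_flag_last_repeat(1, 2, repeated=False, last_repeat=False)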
......
@@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp {
   uint32_t PrepareFlags() const override;

   // Base-class override for executing specific RepeatOp configurations. This code will be called
-  // during the execution tree prepare phase when it is visiting this operator.
-  Status PrepareNodeAction() override;
+  // during the execution tree post-prepare phase when it is visiting this operator.
+  Status PrepareNodePostAction() override;

   // This function returns the buffer that is at the top of our output connector. The caller is
   // typically our parent node, when the parent is asking us to provide the next buffer of data.
......
@@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp {
   int32_t num_producers() const override;

  private:
   int32_t max_repeats_;    // The number of repeats that the user requested
   int32_t repeat_count_;   // A counter for the current number of executed repeats
-  std::vector<std::shared_ptr<DatasetOp>> leaf_ops_;  // List of leaf operators underneath this repeat.
+  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;   // List of operators that can generate EOE underneath this repeat.
 };
 }  // namespace dataset
 }  // namespace mindspore
......
......
@@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() {
 // Recursive function used during prepare phase to visit a node and drive any pre- and post-
 // node actions during a tree walk.
 Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
-  int32_t num_children = dataset_op->child_.size();
+  // Execute the pre-action
+  RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());

-  // Before going down into children, make any prepare flags updates based on this
-  // operator.
+  // Before going down into children, make any prepare flags updates based on this operator.
   uint32_t op_prep_flags = dataset_op->PrepareFlags();
-  // Sanity check. In future we can support nested repeats. for now it's not allowed.
-  // If somebody above us already set the repeat flag, and now we are another repeat...
-  if (BitTest(op_prep_flags, kDePrepRepeat) && BitTest(prepare_flags_, kDePrepRepeat)) {
-    std::string err_msg("Nested RepeatOp detected! This is not supported yet.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
   BitSet(&prepare_flags_, op_prep_flags);

   // Now, descend to children
-  for (int32_t i = 0; i < num_children; ++i) {
-    RETURN_IF_NOT_OK(this->PrepareNode(dataset_op->child_[i]));
+  for (const auto &i : dataset_op->child_) {
+    RETURN_IF_NOT_OK(this->PrepareNode(i));
   }

-  // No more children, now we execute any prepare actions before going back up the
-  // the tree on recursive function exit
-  RETURN_IF_NOT_OK(dataset_op->PrepareNodeAction());
-  // Then clear the flags from this op now that we have prepared it.
-  BitClear(&prepare_flags_, op_prep_flags);
+  // No more children, now we execute any prepare actions before going back up
+  // the tree on recursive function exit
+  RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());
   return Status::OK();
 }
......
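With the sanity check gone, the walk is symmetric: pre-actions fire on the way down (which is how the repeat flag reaches every descendant) and post-actions on the way up (which is how the repeat stack gets drained). A compact Python rendering of the control flow, reusing the hook names from the sketch above (error propagation via RETURN_IF_NOT_OK is elided):

def prepare_node(op, tree):
    op.prepare_node_pre_action(tree)        # runs while descending
    tree.prepare_flags |= op.prepare_flags  # flags accumulate; nested repeats
                                            # are no longer rejected here
    for child in op.children:
        prepare_node(child, tree)
    op.prepare_node_post_action(tree)       # runs while unwinding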
......
@@ -419,6 +419,8 @@ class Dataset:
             >>> repeat_and_shuffle = data.repeat(50)
             >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
         """
+        if count == 1:
+            return self
         return RepeatDataset(self, count)

     @check_zip_dataset
......
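One user-visible consequence of the early return: repeat(1) now hands back the dataset object itself instead of wrapping it in a RepeatDataset node. A quick check of the new behavior (a sketch assuming an installed mindspore; the generator is a stand-in):

import numpy as np
import mindspore.dataset as ds

def gen():
    yield (np.array([0]),)

data = ds.GeneratorDataset(gen, ["data"])
assert data.repeat(1) is data      # no RepeatDataset node is inserted
assert data.repeat(2) is not data  # other counts still wrap as before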
......
@@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
   auto my_tree = std::make_shared<ExecutionTree>();

   std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32);
-  std::shared_ptr<DatasetOp> leaf_op = std::make_shared<RepeatOp>(16);
+  std::string dataset_path;
+  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
+  // TFReaderOp
+  std::shared_ptr<TFReaderOp> my_tfreader_op;
+  TFReaderOp::Builder builder;
+  builder.SetDatasetFilesList({dataset_path})
+      .SetRowsPerBuffer(16)
+      .SetWorkerConnectorSize(16)
+      .SetNumWorkers(16);
+  Status rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+  rc = my_tree->AssociateNode(my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
   my_tree->AssociateNode(parent_op);
-  my_tree->AssociateNode(leaf_op);
   ASSERT_NE(parent_op, nullptr);
-  ASSERT_NE(leaf_op, nullptr);
-  parent_op->AddChild(std::move(leaf_op));
-  parent_op->Print(std::cout, false);
-  parent_op->PrepareNodeAction();
+  ASSERT_NE(my_tfreader_op, nullptr);
+  parent_op->AddChild(std::move(my_tfreader_op));
+  MS_LOG(INFO) << parent_op;
+  my_tree->Prepare();

   RepeatOp RepeatOpOp();

   std::shared_ptr<RepeatOp> repeat_op;
-  Status rc = RepeatOp::Builder(3).Build(&repeat_op);
+  rc = RepeatOp::Builder(3).Build(&repeat_op);
   ASSERT_NE(repeat_op, nullptr);
 }
......
@@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
 from util import save_and_check
 import mindspore.dataset as ds
 import numpy as np
+from mindspore import log as logger

 DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
......
@@ -95,6 +96,141 @@ def test_tf_repeat_03():
     assert num_iter == 2


+def generator():
+    for i in range(3):
+        yield np.array([i]),
+
+
+def test_nested_repeat1():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(3)
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3 * 3
+
+
+def test_nested_repeat2():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(1)
+    data = data.repeat(1)
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 3
+
+
+def test_nested_repeat3():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(1)
+    data = data.repeat(2)
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3
+
+
+def test_nested_repeat4():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(1)
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3
+
+
+def test_nested_repeat5():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.batch(3)
+    data = data.repeat(2)
+    data = data.repeat(3)
+    for i, d in enumerate(data):
+        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
+
+    assert sum([1 for _ in data]) == 6
+
+
+def test_nested_repeat6():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.batch(3)
+    data = data.repeat(3)
+    for i, d in enumerate(data):
+        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
+
+    assert sum([1 for _ in data]) == 6
+
+
+def test_nested_repeat7():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(3)
+    data = data.batch(3)
+    for i, d in enumerate(data):
+        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
+
+    assert sum([1 for _ in data]) == 6
+
+
+def test_nested_repeat8():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.batch(2, drop_remainder=False)
+    data = data.repeat(2)
+    data = data.repeat(3)
+    for i, d in enumerate(data):
+        if i % 2 == 0:
+            assert np.array_equal(d[0], np.asarray([[0], [1]]))
+        else:
+            assert np.array_equal(d[0], np.asarray([[2]]))
+
+    assert sum([1 for _ in data]) == 6 * 2
+
+
+def test_nested_repeat9():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat()
+    data = data.repeat(3)
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+        if i == 10:
+            break
+
+
+def test_nested_repeat10():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(3)
+    data = data.repeat()
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+        if i == 10:
+            break
+
+
+def test_nested_repeat11():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(3)
+    data = data.repeat(4)
+    data = data.repeat(5)
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
+
+
 if __name__ == "__main__":
     logger.info("--------test tf repeat 01---------")
     # test_repeat_01()
......
@@ -104,4 +240,3 @@ if __name__ == "__main__":
     logger.info("--------test tf repeat 03---------")
     test_tf_repeat_03()