提交 34bfa2f7 编写于 作者: J jiangzhiwen

fix skip

上级 9399dffe
......@@ -16,6 +16,7 @@
#include <iostream>
#include <utility>
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
......@@ -26,7 +27,10 @@
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
}
Status SkipOp::Builder::SanityCheck() const {
if (build_max_skips_ < 0) {
......@@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<SkipOp>(build_max_skips_);
*ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
return Status::OK();
}
// Constructor of the SkipOp.
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
: PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
// Destructor
SkipOp::~SkipOp() {}
......@@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
<< "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
}
// Since the buffer may contain multi rows, this function will drop the rows
// that need to skip in it, and then return the buffer.
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Drop first max_skips_ rows
while (skip_count_ < max_skips_) {
if (buf->eoe() || buf->eof()) {
break;
}
// Consider the rows of buffer more than 1
TensorRow drop_row;
int row_num = buf->NumRows();
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
skip_count_ += drop_num;
for (int i = 0; i < drop_num; i++) {
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
}
if (buf->NumRows() == 0) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
}
// Handling eoe
if (buf->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
}
// Handling eof
if (buf->eof()) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
}
*p_buffer = std::move(buf);
return Status::OK();
}
// Base-class override for handling cases when an eoe is received.
Status SkipOp::EoeReceived(int32_t worker_id) {
skip_count_ = 0;
......@@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as a inlined operator, so it is invalid to
// launch the functor since this op runs inlined inside another operator. The
// function is overloaded to ensure that it is not called by mistake (it will
// generate an error).
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
// main entry point for skip
Status SkipOp::operator()() {
TaskManager::FindMe()->Post();
std::unique_ptr<DataBuffer> curr_buffer;
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
while (curr_buffer->eof() == false) {
// Reset count
skip_count_ = 0;
while (curr_buffer->eoe() == false) {
// Drop first count rows
while (skip_count_ < max_skips_) {
if (curr_buffer->eoe() || curr_buffer->eof()) {
break;
}
// Consider the rows of buffer more than one
TensorRow drop_row;
int row_num = curr_buffer->NumRows();
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
skip_count_ += drop_num;
for (int i = 0; i < drop_num; i++) {
RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
}
if (curr_buffer->NumRows() == 0) {
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
}
}
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
}
// we got eoe, now try again until we got eof
MS_LOG(DEBUG) << "Skip operator EOE Received.";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
}
MS_LOG(DEBUG) << "Skip operator EOF Received.";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
return Status::OK();
}
// Base-class override for handling cases when an eof is received.
Status SkipOp::EofReceived(int32_t worker_id) {
......
......@@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
private:
int32_t build_max_skips_;
int32_t builder_op_connector_size_;
Status SanityCheck() const;
};
......@@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
// Constructor of the SkipOp.
// @note The builder class should be used to call it
// @param count - The number of skips to do
explicit SkipOp(int32_t count);
explicit SkipOp(int32_t count, int32_t op_connector_size);
// Destructor
~SkipOp();
......@@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
void Print(std::ostream &out, bool show_all) const override;
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status operator()() override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
// a buffer from our child.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// Base-class override for handling cases when an eoe is received.
// @param worker_id - The worker id
Status EoeReceived(int32_t worker_id) override;
......
......@@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
ASSERT_TRUE(rc.IsOk());
// SkipOp
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
rc = my_tree->AssociateNode(skip_op);
ASSERT_TRUE(rc.IsOk());
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset.transforms.vision.c_transforms as vision
......@@ -51,7 +50,7 @@ def generator_md():
def test_generator_skip():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)
# Here ds1 should be [3, 4]
ds1 = ds1.skip(3)
......@@ -60,6 +59,7 @@ def test_generator_skip():
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [3, 4]
def test_skip_1():
......@@ -72,6 +72,7 @@ def test_skip_1():
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 0
assert buf == []
def test_skip_2():
......@@ -84,6 +85,7 @@ def test_skip_2():
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 5
assert buf == [0, 1, 2, 3, 4]
def test_skip_repeat_1():
......@@ -99,6 +101,7 @@ def test_skip_repeat_1():
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 7
assert buf == [3, 4, 0, 1, 2, 3, 4]
def test_skip_repeat_2():
......@@ -114,6 +117,7 @@ def test_skip_repeat_2():
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 4
assert buf == [3, 4, 3, 4]
def test_skip_repeat_3():
......@@ -132,6 +136,62 @@ def test_skip_repeat_3():
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 6
assert buf == [3, 4, 3, 4, 3, 4]
def test_skip_take_1():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [0, 1, 2, 3]
ds1 = ds1.take(4)
# Here ds1 should be [2, 3]
ds1 = ds1.skip(2)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [2, 3]
def test_skip_take_2():
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [2, 3, 4]
ds1 = ds1.skip(2)
# Here ds1 should be [2, 3]
ds1 = ds1.take(2)
buf = []
for data in ds1:
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [2, 3]
def generator_1d():
for i in range(64):
yield (np.array([i]), )
def test_skip_filter_1():
dataset = ds.GeneratorDataset(generator_1d, ['data'])
dataset = dataset.skip(5)
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
buf = []
for item in dataset:
buf.append(item[0][0])
assert buf == [5, 6, 7, 8, 9, 10]
def test_skip_filter_2():
dataset = ds.GeneratorDataset(generator_1d, ['data'])
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
dataset = dataset.skip(5)
buf = []
for item in dataset:
buf.append(item[0][0])
assert buf == [5, 6, 7, 8, 9, 10]
if __name__ == "__main__":
......@@ -142,3 +202,7 @@ if __name__ == "__main__":
test_skip_repeat_1()
test_skip_repeat_2()
test_skip_repeat_3()
test_skip_take_1()
test_skip_take_2()
test_skip_filter_1()
test_skip_filter_2()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册