提交 c56fe3aa 编写于 作者: M ms_yan

modify take op with an operator

上级 37e35827
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <utility> #include <utility>
#include "common/utils.h" #include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h" #include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
...@@ -25,7 +26,10 @@ ...@@ -25,7 +26,10 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Builder constructor. Creates the builder object. // Builder constructor. Creates the builder object.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {} TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
}
Status TakeOp::Builder::SanityCheck() const { Status TakeOp::Builder::SanityCheck() const {
if (build_max_takes_ <= 0) { if (build_max_takes_ <= 0) {
...@@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const { ...@@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object. // The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) { Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck()); RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<TakeOp>(build_max_takes_); *ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
return Status::OK(); return Status::OK();
} }
// Constructor of the TakeOp. // Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {} TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
: PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
// A print method typically used for debugging // A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const { void TakeOp::Print(std::ostream &out, bool show_all) const {
...@@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const { ...@@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
} }
} }
// This function will be call muti times to returns the buffer, when meet required max take count or meet // Main entry point for Take
// EOF buffer then this will stop. Status TakeOp::operator()() {
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) { TaskManager::FindMe()->Post();
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf; std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); while (buf->eof() == false) {
if (take_count_ == max_takes_) { if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) { // Do drain Operation
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer."; while (!buf->eoe() && !buf->eof()) {
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;
// Reset the count and drain
if (!last_repeat) {
take_count_ = 0;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
} }
} else if (state_ == OpState::kDeOpIdle) { }
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); // Loop until non EOE is received
*p_buffer = std::move(eof_buffer); if (buf->eoe()) {
take_count_ = 0; take_count_ = 0;
} else { RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
MS_LOG(WARNING) << "Invalid OpState: " << state_; RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
continue;
} }
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
*p_buffer = std::move(buf);
return Status::OK();
}
// Check if the last buf is next eof // Get buffer and push back when take_count is still small
if (buf->eof()) { if (take_count_ < max_takes_) {
*p_buffer = std::move(buf); std::unique_ptr<DataBuffer> p_buffer;
return Status::OK(); RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer)));
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
} }
// Get buffer and push back when take_count is still small take_count_ = 0;
if (take_count_ < max_takes_) { MS_LOG(DEBUG) << "Meet the end and push-back eof buffer.";
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer)); auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
} RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
return Status::OK(); return Status::OK();
} }
...@@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D ...@@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return Status::OK(); return Status::OK();
} }
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
Status TakeOp::PrepareNodePostAction() { Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToRepeatStack(shared_from_this()); tree_->AddToRepeatStack(shared_from_this());
......
...@@ -45,6 +45,7 @@ class TakeOp : public PipelineOp { ...@@ -45,6 +45,7 @@ class TakeOp : public PipelineOp {
private: private:
int32_t build_max_takes_; int32_t build_max_takes_;
int32_t builder_op_connector_size_;
Status SanityCheck() const; Status SanityCheck() const;
}; };
...@@ -52,7 +53,7 @@ class TakeOp : public PipelineOp { ...@@ -52,7 +53,7 @@ class TakeOp : public PipelineOp {
// Constructor of the TakeOp. // Constructor of the TakeOp.
// @note The builder class should be used to call it // @note The builder class should be used to call it
// @param count - The number of takes to do // @param count - The number of takes to do
explicit TakeOp(int32_t count); explicit TakeOp(int32_t count, int32_t op_connector_size);
// Destructor // Destructor
~TakeOp() = default; ~TakeOp() = default;
...@@ -72,23 +73,11 @@ class TakeOp : public PipelineOp { ...@@ -72,23 +73,11 @@ class TakeOp : public PipelineOp {
return out; return out;
} }
// Class functor operator () override. // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// Most dataset ops operate by launching a thread (see ExecutionTree). // provide the master loop that drives the logic for performing the work
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return // @return Status - The error code return
Status operator()() override; Status operator()() override;
// Gets a buffer from the child node. The caller is typically our parent node.
// @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
// this function will retry to pop the connector again and will get the non-EOE buffer if any.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// During tree prepare phase, operators may have specific post-operations to perform depending on // During tree prepare phase, operators may have specific post-operations to perform depending on
// their role. // their role.
// @notes Derived versions of this function should always call it's superclass version first // @notes Derived versions of this function should always call it's superclass version first
......
...@@ -30,6 +30,12 @@ def generator_10(): ...@@ -30,6 +30,12 @@ def generator_10():
yield np.array([i]), yield np.array([i]),
def filter_func_ge(data):
if data > 3:
return False
return True
def test_take_01(): def test_take_01():
""" """
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
...@@ -297,6 +303,44 @@ def test_take_16(): ...@@ -297,6 +303,44 @@ def test_take_16():
assert sum([1 for _ in data1]) == 5 assert sum([1 for _ in data1]) == 5
def test_take_17():
"""
Test take: take first, then do fiter operation
"""
logger.info("test_take_17")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i == d[0][0]
assert sum([1 for _ in data1]) == 4
def test_take_18():
"""
Test take: take first, then do fiter, skip, batch and repeat operation
"""
logger.info("test_take_18")
data1 = ds.GeneratorDataset(generator_10, ["data"])
data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
data1 = data1.skip(2)
data1 = data1.batch(2)
data1 = data1.repeat(2)
# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 == d[0][0]
assert sum([1 for _ in data1]) == 2
if __name__ == '__main__': if __name__ == '__main__':
test_take_01() test_take_01()
test_take_02() test_take_02()
...@@ -314,4 +358,6 @@ if __name__ == '__main__': ...@@ -314,4 +358,6 @@ if __name__ == '__main__':
test_take_14() test_take_14()
test_take_15() test_take_15()
test_take_16() test_take_16()
test_take_17()
test_take_18()
logger.info('== test take operation finished ==') logger.info('== test take operation finished ==')
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册