diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 3f737c167cd1d956870186fa309d8543009c5fca..b3c595870f8ecaddcd9c842efcea9996875bb7e8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -91,11 +91,14 @@ void Sampler::Print(std::ostream &out, bool show_all) const { Status Sampler::GetAllIdsThenReset(py::array *data) { std::unique_ptr db; std::shared_ptr sample_ids; + TensorRow sample_row; // A call to derived class to get sample ids wrapped inside a buffer RETURN_IF_NOT_OK(GetNextSample(&db)); // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch - RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); + RETURN_IF_NOT_OK(db->GetRow(0, &sample_row)); + sample_ids = sample_row[0]; + // check this buffer is not a ctrl buffer CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); { diff --git a/tests/ut/cpp/dataset/zip_op_test.cc b/tests/ut/cpp/dataset/zip_op_test.cc index 7885369c07545c868849c8495034296c572cf7dc..f8f8fe89db5e96e38bc4bfae6194c2e3e860745f 100644 --- a/tests/ut/cpp/dataset/zip_op_test.cc +++ b/tests/ut/cpp/dataset/zip_op_test.cc @@ -125,7 +125,6 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) { EXPECT_TRUE(rc.IsOk()); row_count++; } - MS_LOG(WARNING) <<"row count is: " << row_count; ASSERT_EQ(row_count, 3); // Should be 3 rows fetched }