Commit 78c1aa1d authored by Zirui Wu

Implemented Callback for Dataset

implement pause in MapOp, added more to callback

add ds_callback

- Initial drop of Python DSCallback

- Pybind DSCallback

- Pybind DSCallback

added callback to mapOp

- de_pipeline DSCallback

- de_pipeline DSCallback

add test case, segfault for now

fix seg fault

- de_pipeline DSCallback

remove 1 line

update callback test case, now works

use builder class for mapOp callback

- de_pipeline DSCallback

- de_pipeline DSCallback

- de_pipeline DSCallback

better test case

minor fix

add comments and minor clean ups

get rid of nullptr in MapOp, use other flag instead

fix a bug where ParseMapOp only took 1 callback

- Added WaitedDSCallback

refactor callback param

fix incorrect number in test case

- added testing

fix cpp test case

- added testing

- revert lenet changes

- cleanup test_callbacks.py

- cleanup test_callbacks.py

fix CI stage I

fix CI stage II

fix CI and update epoch counter

- add validation
- add more testing in test_callbacks.py

use random data op to do tests

adjust when to call EpochBegin/End

- add repeat with callback

- addressing reviewers' comments

- docstring and CI fixes

- docstring and CI fixes

- docstring and CI fixes

- rebase with upstream/master

fix cpp test case

fix review comments

address review comments, add test case
Parent 89cd4652
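For orientation, the end-to-end usage this commit enables looks roughly like the sketch below. It mirrors the DSCallback docstring added in this change; the PrintInfo class, the NumpySlicesDataset source, and the identity py_func are illustrative assumptions, not part of the diff.

import mindspore.dataset as ds
from mindspore.dataset.callback import DSCallback

class PrintInfo(DSCallback):
    # hypothetical callback: report progress at epoch boundaries
    def ds_epoch_begin(self, ds_run_context):
        print("epoch", ds_run_context.cur_epoch_num, "begins")

    def ds_epoch_end(self, ds_run_context):
        print("epoch", ds_run_context.cur_epoch_num, "ended at global step",
              ds_run_context.cur_step_num)

data = ds.NumpySlicesDataset({"col1": [1, 2, 3, 4]}, shuffle=False)
data = data.map(input_columns=["col1"], operations=[(lambda x: x)], callbacks=PrintInfo())
for _ in data.create_dict_iterator():
    pass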
......@@ -58,6 +58,7 @@ add_subdirectory(kernels)
add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(text)
add_subdirectory(callback)
######################################################################
add_dependencies(utils core)
add_dependencies(kernels-image core)
......@@ -74,6 +75,7 @@ add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
add_dependencies(callback core)
add_dependencies(text core)
add_dependencies(text-kernels core)
add_dependencies(cpp-API core)
......@@ -87,6 +89,7 @@ endif ()
################### Create _c_dataengine Library ######################
set(submodules
$<TARGET_OBJECTS:core>
$<TARGET_OBJECTS:callback>
$<TARGET_OBJECTS:utils>
$<TARGET_OBJECTS:kernels>
$<TARGET_OBJECTS:kernels-image>
......@@ -135,14 +138,14 @@ endif()
target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
if (ENABLE_PYTHON)
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
else()
target_link_libraries(_c_dataengine PRIVATE mindspore::protobuf ${SECUREC_LIBRARY})
endif()
else()
set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n)
if (ENABLE_PYTHON)
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
else()
target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY})
endif()
......
......@@ -7,6 +7,7 @@ if (ENABLE_PYTHON)
python/bindings.cc
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/core/bindings.cc
python/bindings/dataset/callback/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc
python/bindings/dataset/kernels/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/callback/ds_callback.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
(void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback")
.def(py::init<int32_t>())
.def("set_begin", &PyDSCallback::setBegin)
.def("set_end", &PyDSCallback::setEnd)
.def("set_epoch_begin", &PyDSCallback::setEpochBegin)
.def("set_epoch_end", &PyDSCallback::setEpochEnd)
.def("set_step_begin", &PyDSCallback::setStepBegin)
.def("set_step_end", &PyDSCallback::setStepEnd);
}));
PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
(void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam")
.def(py::init<int64_t, int64_t, int64_t>())
.def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
.def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
.def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
}));
} // namespace dataset
} // namespace mindspore
......@@ -20,6 +20,7 @@
#include <map>
#include "utils/ms_utils.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/dataset_iterator.h"
......@@ -738,8 +739,13 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "callbacks") {
std::vector<std::shared_ptr<DSCallback>> callbacks;
std::transform(value.begin(), value.end(), std::back_inserter(callbacks),
[](py::handle cb) { return cb.cast<std::shared_ptr<PyDSCallback>>(); });
(void)map_builder.AddCallbacks(callbacks);
} else {
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
RETURN_STATUS_UNEXPECTED("Error in parsing MapOp: Unhandled key: " + key);
}
}
}
......
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
add_library(callback OBJECT
callback_manager.cc
py_ds_callback.cc
)
else ()
add_library(callback OBJECT
callback_manager.cc
)
endif ()
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
namespace mindspore {
namespace dataset {
void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
}
Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
RETURN_UNEXPECTED_IF_NULL(op);
op_ = op;
// turn the flag on if callback is set
enabled_ = !callbacks_.empty();
// error check for each of the callbacks
for (auto &cb : callbacks_) {
CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
}
return Status::OK();
}
Status CallbackManager::Begin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
}
return Status::OK();
}
Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no epoch_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
}
return Status::OK();
}
Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
callback_inds.push_back(ind);
}
// return Status::OK() if no step_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
}
return Status::OK();
}
Status CallbackManager::End(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
}
return Status::OK();
}
Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no epoch_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
}
return Status::OK();
}
Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
callback_inds.push_back(ind);
}
// return Status::OK() if no step_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
#include <memory>
#include <vector>
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// forward declare to avoid cyclic include of dataset_op.h
class DatasetOp;
/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
class CallbackManager {
public:
/// CallbackManager default constructor. Init needs to be called before using the created instance.
CallbackManager() : enabled_(false) {}
/// \brief Add a list of callbacks to this manager
/// \param [in] callbacks list of callbacks to perform
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);
/// \brief DatasetOp needs to call Init if it wishes to use callbacks; Init sets enabled_ to true when callbacks are present
/// \param[in] op this pointer is used by the CallbackManager to pause worker threads
/// \return Status
Status Init(std::shared_ptr<DatasetOp> op);
/// \brief callback function called at the start of the first row
/// \return Status
Status Begin(const CallbackParam &);
/// \brief callback function called at the start of each epoch
/// \return Status
Status EpochBegin(const CallbackParam &);
/// \brief callback function called at the start of each row
/// \return Status
Status StepBegin(const CallbackParam &);
/// \brief callback function called after the last row is processed
/// \return Status
Status End(const CallbackParam &);
/// \brief callback function called at the end of each epoch
/// \return Status
Status EpochEnd(const CallbackParam &);
/// \brief callback function called at the end of each row
/// \return Status
Status StepEnd(const CallbackParam &);
private:
bool enabled_; // flag to enable callback, if false, all functions would return immediately
std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
#include <nlohmann/json.hpp>
namespace mindspore {
namespace dataset {
/// CallbackParam is the object a DatasetOp uses to pass run-time information to user-defined functions.
/// This is a prototype for now; more fields will be added
class CallbackParam {
public:
CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
: cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {}
// these are constant public fields for easy access and consistency with python cb_param
// the names and orders are consistent with batchInfo
const int64_t cur_epoch_num_; // current epoch
const int64_t cur_epoch_step_num_; // step number of the current epoch
const int64_t cur_step_num_; // step number since the first row
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
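On the Python side, the pybind registration shown earlier exposes these three read-only fields as cur_epoch_num, cur_step_num_in_epoch, and cur_step_num. A minimal sketch of a step-level callback reading them (the StepLogger class name is hypothetical):

from mindspore.dataset.callback import DSCallback

class StepLogger(DSCallback):
    # hypothetical callback: log the run-time info carried by CallbackParam
    def ds_step_begin(self, ds_run_context):
        # cur_step_num_in_epoch resets at every epoch; cur_step_num counts from the first row
        print("epoch {}: step {} (global step {})".format(
            ds_run_context.cur_epoch_num,
            ds_run_context.cur_step_num_in_epoch,
            ds_run_context.cur_step_num))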
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class DSCallback {
public:
/// \brief constructor of DSCallback, this is the base class for all front-end-specific callbacks
/// \param step_size number of steps between successive calls to DSNStepBegin()/DSNStepEnd()
explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}
/// \brief actual callback function for begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSBegin(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSEnd(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for epoch_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for step_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;
/// \brief predicate function, whether begin callback is needed
/// \return bool
virtual bool IsBeginNeeded() = 0;
/// \brief predicate function, whether epoch_begin callback is needed
/// \return bool
virtual bool IsEpochBeginNeeded() = 0;
/// \brief predicate function, whether step_begin callback is needed
/// \return bool
virtual bool IsNStepBeginNeeded() = 0;
/// \brief predicate function, whether end callback is needed
/// \return bool
virtual bool IsEndNeeded() = 0;
/// \brief predicate function, whether epoch_end callback is needed
/// \return bool
virtual bool IsEpochEndNeeded() = 0;
/// \brief predicate function, whether step_end callback is needed
/// \return bool
virtual bool IsNStepEndNeeded() = 0;
/// \brief getter
/// \return step_size
int32_t step_size() const { return step_size_; }
protected:
int32_t step_size_; // step begin/end will be called every step_size_
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
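The step_size member carries through to the Python DSCallback constructor; with step_size=N, the step callbacks fire only when (cur_step_num_in_epoch - 1) % N == 0, i.e. on steps 1, N+1, 2N+1, ... of each epoch. A sketch under that assumption (EveryFourSteps is a hypothetical name):

from mindspore.dataset.callback import DSCallback

class EveryFourSteps(DSCallback):
    # hypothetical callback: ds_step_begin runs on steps 1, 5, 9, ... of each epoch
    def __init__(self):
        super().__init__(step_size=4)

    def ds_step_begin(self, ds_run_context):
        print("starting step", ds_run_context.cur_step_num_in_epoch)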
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status PyDSCallback::DSBegin(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(begin_func_, cb_param);
}
Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param);
}
Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param);
}
Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); }
Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param);
}
Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param);
}
bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
bool PyDSCallback::IsEndNeeded() { return end_needed_; }
Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
{
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
f(cb_param);
}
return Status::OK();
}
void PyDSCallback::setBegin(py::function f) {
begin_func_ = f;
begin_needed_ = true;
}
void PyDSCallback::setEnd(py::function f) {
end_func_ = f;
end_needed_ = true;
}
void PyDSCallback::setEpochBegin(py::function f) {
epoch_begin_func_ = f;
epoch_begin_needed_ = true;
}
void PyDSCallback::setEpochEnd(py::function f) {
epoch_end_func_ = f;
epoch_end_needed_ = true;
}
void PyDSCallback::setStepBegin(py::function f) {
step_begin_func_ = f;
step_begin_needed_ = true;
}
void PyDSCallback::setStepEnd(py::function f) {
step_end_func_ = f;
step_end_needed_ = true;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace dataset {
namespace py = pybind11;
class PyDSCallback : public DSCallback {
public:
/// \brief constructor for PyDSCallback. This callback is for the Python front end
explicit PyDSCallback(int32_t step_size = 1)
: DSCallback(step_size),
begin_needed_(false),
epoch_begin_needed_(false),
step_begin_needed_(false),
end_needed_(false),
epoch_end_needed_(false),
step_end_needed_(false) {}
void setBegin(py::function f);
void setEnd(py::function f);
void setEpochBegin(py::function f);
void setEpochEnd(py::function f);
void setStepBegin(py::function f);
void setStepEnd(py::function f);
/// \brief actual callback function for begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSBegin(const CallbackParam &cb_param) override;
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSEpochBegin(const CallbackParam &cb_param) override;
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSNStepBegin(const CallbackParam &cb_param) override;
/// \brief actual callback function for end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSEnd(const CallbackParam &cb_param) override;
/// \brief actual callback function for epoch_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSEpochEnd(const CallbackParam &cb_param) override;
/// \brief actual callback function for step_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSNStepEnd(const CallbackParam &cb_param) override;
/// \brief predicate function, whether begin callback is needed
/// \return bool
bool IsBeginNeeded() override;
/// \brief predicate function, whether epoch_begin callback is needed
/// \return bool
bool IsEpochBeginNeeded() override;
/// \brief predicate function, whether step_begin callback is needed
/// \return bool
bool IsNStepBeginNeeded() override;
/// \brief predicate function, whether end callback is needed
/// \return bool
bool IsEndNeeded() override;
/// \brief predicate function, whether epoch_end callback is needed
/// \return bool
bool IsEpochEndNeeded() override;
/// \brief predicate function, whether step_end callback is needed
/// \return bool
bool IsNStepEndNeeded() override;
/// \brief helper function to acquire GIL then execute a pyfunc
/// \param f the python function
/// \param cb_param the callback parameter passed to the python function
/// \return Status
static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);
private:
py::function begin_func_;
py::function epoch_begin_func_;
py::function step_begin_func_;
py::function end_func_;
py::function epoch_end_func_;
py::function step_end_func_;
bool begin_needed_;
bool epoch_begin_needed_;
bool step_begin_needed_;
bool end_needed_;
bool epoch_end_needed_;
bool step_end_needed_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
......@@ -21,6 +21,8 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/util/status.h"
......@@ -358,6 +360,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return boolean returns true if it's last iteration
bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp.
/// The expected behavior is that, when this function is invoked, it blocks until all the workers
/// have finished their remaining work and gone to sleep. Since all ParallelOps use a QueueList to sync with the
/// master, they automatically wait on the QueueList when they are done. Hence, for now, an Unpause() function is
/// not needed. Only ParallelOp needs to override this function.
/// \return Status
virtual Status PauseFromMaster() { return Status::OK(); }
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
......@@ -394,6 +404,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
std::unique_ptr<DbConnector> out_connector_; // Output Connector
std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
std::mutex column_name_map_mutex_; // For protecting shared access to the column map
CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp
private:
/// Sets the operator id.
......
......@@ -15,25 +15,23 @@
*/
#include <algorithm>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
......@@ -58,6 +56,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
RETURN_IF_NOT_OK(sanityCheck());
*ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
(*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_));
return Status::OK();
}
......@@ -164,7 +163,10 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)
Status MapOp::operator()() {
// Create and register the local queues.
local_queues_.Init(num_workers_, oc_queue_size_);
// init callback
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
Status rc = local_queues_.Register(tree_->AllTasks());
RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
if (rc.IsError()) {
TaskManager::FindMe()->Post();
return rc;
......@@ -175,28 +177,51 @@ Status MapOp::operator()() {
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(rc);
// num_buf: number of buffers received (including eoe); ep_step: step count in the current epoch; total_step: total steps since the first row
int64_t num_buf = 0, ep_step = 0, total_step = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
int64_t que_id = 0;
std::unique_ptr<DataBuffer> buff;
bool is_eof = false;
// Drain output connector of the previous op, generate jobs for worker threads, and distribute them via local queues
// Stop when all worker threads are finished (received EOF)
while (!is_eof) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
is_eof = buff->eof();
// Create an empty map worker job to be populated by a databuffer and map jobs
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>();
worker_job->databuffer = std::move(buff);
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
while (!buff->eof()) {
if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
while (!buff->eoe()) {
ep_step++;
total_step++;
// Create an empty map worker job to be populated by a databuffer and map jobs
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
// Populate map worker job for a worker to execute
RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
// Populate map worker job for a worker to execute
RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
// Push map worker job to the corresponding worker's queue
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
// Push map worker job to the corresponding worker's queue
RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(worker_job)));
que_id = (que_id + 1) % num_workers_;
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}
// send the eoe buffer to worker
// reset epoch_step when a new epoch is about to start
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
ep_step = 0;
}
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
UpdateRepeatAndEpochCounter();
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}
// the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback
// RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
// handle eof logic
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
return Status::OK();
}
......@@ -213,25 +238,19 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
// Fetch next data buffer and map job list
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
// Sanity check the databuffer.
// Special case: if there's more threads than buffers, some threads simply get the final control
// messages (eoe/eof), and so they will not perform the check.
if (!in_buffer->eoe() && !in_buffer->eof()) {
int32_t num_rows = in_buffer->NumRows();
int32_t num_cols = in_buffer->NumCols();
if (num_rows == 0 || num_cols == 0) {
RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer.");
}
}
// Now that init work is done, drop into the main fetching loop.
// Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
// rather than use the base-class defaults.
while (true) {
// Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
// with Performance Mode design.
if (in_buffer->eoe()) {
UpdateRepeatAndEpochCounter();
// handle the pause logic. Pause is triggered when a buffer with id -1, no special flag, and no rows is received
if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
// when a worker receives the signal from the master thread, it increments an atomic int
// the last worker to increment the counter wakes up the master thread
if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
// this will block the worker until master thread gives it a new work
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
continue;
} else if (in_buffer->eoe()) {
// Calling base class EoeReceived to forward eoe buffer.
RETURN_IF_NOT_OK(EoeReceived(worker_id));
// Fetch next data buffer and map job list
......@@ -243,6 +262,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
break;
}
CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
// Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list));
......@@ -281,9 +301,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
std::vector<TensorRow> result_table;
// Executing the list of jobs
for (size_t i = 0; i < job_list.size(); i++) {
// Executre MapJob.
// Execute MapJob.
RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
// Assign the pocessed data as an input for the next job processing, except for the last TensorOp in the list.
// Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
if (i + 1 < job_list.size()) {
job_input_table = std::move(result_table);
}
......@@ -428,5 +448,20 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<MapOp>(), modified);
}
Status MapOp::PauseFromMaster() {
// reset num_paused workers to 0
num_workers_paused_ = 0;
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
// a special buffer (id=-1, empty, none flag) is used to signal that the worker needs to pause.
RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
}
// wait until all workers are done processing their work in local_queues_
RETURN_IF_NOT_OK(master_pause_wp_.Wait());
// clear the WaitPost for the next Wait()
master_pause_wp_.Clear();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......@@ -16,15 +16,19 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
......@@ -108,6 +112,13 @@ class MapOp : public ParallelOp {
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &AddCallbacks(const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
builder_callbacks_.insert(builder_callbacks_.end(), callbacks.begin(), callbacks.end());
return *this;
}
// The builder "build" method creates the final object.
// @param ptr The shared_ptr to the new MapOp object
// @return Status
......@@ -116,6 +127,7 @@ class MapOp : public ParallelOp {
private:
std::vector<std::string> build_in_col_names_;
std::vector<std::string> build_out_col_names_;
std::vector<std::shared_ptr<DSCallback>> builder_callbacks_;
std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
int32_t build_num_workers_;
int32_t build_op_connector_size_;
......@@ -186,6 +198,7 @@ class MapOp : public ParallelOp {
// A unit of job for map worker thread.
// MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
struct MapWorkerJob {
explicit MapWorkerJob(std::unique_ptr<DataBuffer> db) : databuffer(std::move(db)) {}
std::vector<std::shared_ptr<MapJob>> jobs;
std::unique_ptr<DataBuffer> databuffer;
};
......@@ -215,6 +228,12 @@ class MapOp : public ParallelOp {
// Indices of the columns to process.
std::vector<size_t> to_process_indices_;
// wait post used to perform the pausing logic in MapOp
WaitPost master_pause_wp_;
// count number of workers that have signaled master
std::atomic_int num_workers_paused_;
// Private function for worker/thread to loop continuously. It comprises the main
// logic of MapOp: getting the data from previous Op, validating user specified column names,
// applying a list of TensorOps to each of the data, process the results and then
......@@ -247,6 +266,13 @@ class MapOp : public ParallelOp {
// Private function for initializing private variables such as in_columns_, out_columns_.
// @return - Status
Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);
// This function should only be called from the master thread. It suspends the operation of all workers and
// has them wait on the QueueList. The master thread sends a token to each worker, then waits on a WaitPost.
// Upon receiving the suspension token from the master thread, each worker increments an atomic count; the last
// worker to do the increment wakes up the master.
// @return - Status
Status PauseFromMaster() override;
};
} // namespace dataset
} // namespace mindspore
......
......@@ -34,7 +34,7 @@ class Semaphore {
/// \brief Decrement the internal counter. Will be blocked if the value is 0.
/// \return Error code. Can get interrupt.
Status P();
/// \brief Increment the internal counter. Wakeup on of the watiers if any.
/// \brief Increment the internal counter. Wake up one of the waiters if any.
void V();
/// \brief Peek the internal value
/// \return The internal value
......
......@@ -59,6 +59,13 @@ namespace dataset {
} \
} while (false)
#define RETURN_OK_IF_TRUE(_condition) \
do { \
if (_condition) { \
return Status::OK(); \
} \
} while (false)
enum class StatusCode : char {
kOK = 0,
kOutOfMemory = 1,
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""init file for python callback"""
from .ds_callback import DSCallback, WaitedDSCallback
__all__ = ["DSCallback", "WaitedDSCallback"]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Python callback class
"""
import threading
from mindspore._c_dataengine import PyDSCallback
from mindspore.train.callback import Callback
from .validators import check_callback
class DSCallback:
"""
Abstract base class used to build a dataset callback class.
Args:
step_size (int, optional): The number of steps between adjacent step_begin/step_end calls (Default=1).
Examples:
>>> class PrintInfo(DSCallback):
>>> def ds_epoch_end(self, ds_run_context):
>>> print(ds_run_context.cur_epoch_num)
>>> print(ds_run_context.cur_step_num)
>>>
>>> data = data.map(operations=op, callbacks=PrintInfo())
"""
@check_callback
def __init__(self, step_size=1):
self.step_size = step_size
def ds_begin(self, ds_run_context):
"""
Called before the data pipeline is started.
Args:
ds_run_context (RunContext): Include some information of the pipeline.
"""
def ds_epoch_begin(self, ds_run_context):
"""
Called before a new epoch is started.
Args:
ds_run_context (RunContext): Include some information of the pipeline.
"""
def ds_epoch_end(self, ds_run_context):
"""
Called after an epoch is finished.
Args:
ds_run_context (RunContext): Include some information of the pipeline.
"""
def ds_step_begin(self, ds_run_context):
"""
Called before n steps are started.
Args:
ds_run_context (RunContext): Include some information of the pipeline.
"""
def ds_step_end(self, ds_run_context):
"""
Called after n steps are finished.
Args:
ds_run_context (RunContext): Include some information of the pipeline.
"""
def create_runtime_obj(self):
"""
Creates a runtime (C++) object from the callback methods defined by the user.
Returns: _c_dataengine.PyDSCallback
"""
c_cb = PyDSCallback(self.step_size)
at_least_one = False
if self.__class__.ds_begin != DSCallback.ds_begin:
c_cb.set_begin(self.ds_begin)
at_least_one = True
if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
c_cb.set_epoch_begin(self.ds_epoch_begin)
at_least_one = True
if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
c_cb.set_epoch_end(self.ds_epoch_end)
at_least_one = True
if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
c_cb.set_step_begin(self.ds_step_begin)
at_least_one = True
if self.__class__.ds_step_end != DSCallback.ds_step_end:
c_cb.set_step_end(self.ds_step_end)
at_least_one = True
if not at_least_one:
raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
return c_cb
class WaitedDSCallback(Callback, DSCallback):
"""
Abstract base class used to build a dataset callback class that is synchronized with the training callback.
This class can be used to execute user-defined logic right after the previous step or epoch.
For example, one augmentation needs the loss from the previously trained epoch to update some of its parameters.
Examples:
>>> my_cb = MyWaitedCallback(32)
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
>>> data = data.batch(32)
>>> # define the model
>>> model.train(epochs, data, callbacks=[my_cb])
Args:
step_size: the number of rows in each step.
Usually the step size will be equal to the batch size (Default=1)
"""
def __init__(self, step_size=1):
super().__init__()
self.step_size = step_size
self.step_event = threading.Event()
self.step_run_context = None
self.epoch_event = threading.Event()
self.epoch_run_context = None
def sync_epoch_begin(self, train_run_context, ds_run_context):
"""
Called before a new dataset epoch is started and after the previous training epoch is ended.
Args:
train_run_context: Include some information of the model with feedback from the previous epoch.
ds_run_context: Include some information of the dataset pipeline.
"""
def sync_step_begin(self, train_run_context, ds_run_context):
"""
Called before a new dataset step is started and after the previous training step is ended.
Args:
train_run_context: Include some information of the model with feedback from the previous step.
ds_run_context: Include some information of the dataset pipeline.
"""
def epoch_end(self, run_context):
"""
Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.
Args:
run_context: Include some information of the model.
"""
self.epoch_run_context = run_context
self.epoch_event.set()
self.epoch_event.clear()
def ds_epoch_begin(self, ds_run_context):
"""
Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.
Args:
ds_run_context: Include some information of the pipeline.
"""
if ds_run_context.cur_epoch_num > 1:
if self.epoch_run_context is None:
self.epoch_event.wait()
self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
self.epoch_run_context = None
def step_end(self, run_context):
"""
Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.
Args:
run_context: Include some information of the model.
"""
self.step_run_context = run_context
self.step_event.set()
self.step_event.clear()
def ds_step_begin(self, ds_run_context):
"""
Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.
Args:
ds_run_context: Include some information of the pipeline.
"""
if ds_run_context.cur_step_num > self.step_size:
if self.step_run_context is None:
self.step_event.wait()
self.sync_step_begin(self.step_run_context, ds_run_context)
self.step_run_context = None
def create_runtime_obj(self):
"""
Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
Returns: _c_dataengine.PyDSCallback
"""
c_cb = PyDSCallback(self.step_size)
at_least_one = False
if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
c_cb.set_step_begin(self.ds_step_begin)
at_least_one = True
if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
c_cb.set_epoch_begin(self.ds_epoch_begin)
at_least_one = True
if not at_least_one:
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
return c_cb
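Fleshing out the class docstring example above, a subclass might look like the sketch below. MyWaitedCallback and the update_from_feedback hook on the augmentation op are hypothetical, and the original_args() accessor on the training run context is an assumption about mindspore.train.callback.RunContext.

class MyWaitedCallback(WaitedDSCallback):
    # hypothetical subclass: feed the previous epoch's training feedback into an augmentation op
    def __init__(self, batch_size, aug_op):
        super().__init__(step_size=batch_size)
        self.aug_op = aug_op

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # train_run_context is the run_context captured by epoch_end of the training Callback
        cb_params = train_run_context.original_args()  # assumed RunContext accessor
        self.aug_op.update_from_feedback(cb_params)    # hypothetical hook on the augmentation op

# usage, mirroring the docstring: the same object is passed to both map() and model.train()
# my_cb = MyWaitedCallback(32, aug_op)
# data = data.map(operations=aug_op, callbacks=my_cb).batch(32)
# model.train(epochs, data, callbacks=[my_cb])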
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Built-in validators.
"""
from functools import wraps
from ..core.validator_helpers import parse_user_args, check_pos_int32
def check_callback(method):
"""check the input arguments of DSCallback."""
@wraps(method)
def new_method(self, *args, **kwargs):
[step_size], _ = parse_user_args(method, *args, **kwargs)
check_pos_int32(step_size, "step_size")
return method(self, *args, **kwargs)
return new_method
......@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
check_paddeddataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
......@@ -395,7 +395,7 @@ class Dataset:
@check_map
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None):
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
"""
Apply each operation in operations to this dataset.
......@@ -438,6 +438,8 @@ class Dataset:
option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
callbacks (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None).
Returns:
MapDataset, dataset after mapping operation.
......@@ -552,7 +554,7 @@ class Dataset:
>>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
"""
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
python_multiprocessing, cache)
python_multiprocessing, cache, callbacks)
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
......@@ -1548,6 +1550,7 @@ class DatasetOp(Dataset):
return self.children[0].get_class_indexing()
raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
class BucketBatchByLengthDataset(DatasetOp):
"""
The result of applying BucketBatchByLength operator to the input dataset.
......@@ -1964,14 +1967,14 @@ class MapDataset(DatasetOp):
option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
callbacks (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None)
Raises:
ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
"""
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None):
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
super().__init__(num_parallel_workers)
self.children.append(input_dataset)
if input_columns is not None and not isinstance(input_columns, list):
......@@ -1996,6 +1999,11 @@ class MapDataset(DatasetOp):
self.python_multiprocessing = python_multiprocessing
self.process_pool = None
if callbacks is not None and not isinstance(callbacks, list):
callbacks = [callbacks]
self.callbacks = callbacks
def get_args(self):
args = super().get_args()
args["input_columns"] = self.input_columns
......@@ -2003,6 +2011,9 @@ class MapDataset(DatasetOp):
args["output_columns"] = self.output_columns
args["columns_order"] = self.columns_order
args["cache"] = self.cache.cache_client if self.cache is not None else None
if self.callbacks is not None:
args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks]
return args
def get_dataset_size(self):
......@@ -2034,6 +2045,7 @@ class MapDataset(DatasetOp):
new_op.cache = copy.deepcopy(self.cache, memodict)
new_op.operations = self.operations
new_op.dataset_size = self.dataset_size
new_op.callbacks = self.callbacks
return new_op
# Iterator bootstrap will be called on iterator construction.
......@@ -2393,7 +2405,6 @@ class ConcatDataset(DatasetOp):
self._children_start_end_index_[index][0] = cumulative_samples_nums
self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
tem_sampler = copy.deepcopy(sampler)
tem_sampler.set_offset(cumulative_samples_nums)
child.sampler = tem_sampler
......@@ -2556,7 +2567,7 @@ class RangeDataset(MappableDataset):
def get_dataset_size(self):
if self.dataset_size is None:
self.dataset_size = math.ceil((self.stop - self.start)/self.step)
self.dataset_size = math.ceil((self.stop - self.start) / self.step)
return self.dataset_size
......@@ -3423,7 +3434,7 @@ class GeneratorDataset(MappableDataset):
if not self.num_shards:
self.dataset_size = len(self.source)
else:
self.dataset_size = math.ceil(len(self.source)/self.num_shards)
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
......@@ -5428,6 +5439,7 @@ class NumpySlicesDataset(GeneratorDataset):
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id)
class _PaddedDataset:
"""
Mainly for combining false samples provided by users into a dataset.
......@@ -5435,6 +5447,7 @@ class _PaddedDataset:
Args:
padded_samples (list(dict)): the data provided by user to added to initial Dataset
"""
def __init__(self, padded_samples):
self.column_names = list(padded_samples[0].keys())
self.padded_samples = padded_samples
......@@ -5445,6 +5458,7 @@ class _PaddedDataset:
def __len__(self):
return len(self.padded_samples)
class PaddedDataset(GeneratorDataset):
"""
Create a dataset with fake data provided by user. Mainly used to add to the original data set
......@@ -5463,6 +5477,7 @@ class PaddedDataset(GeneratorDataset):
>>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
>>> ds1 = ds.PaddedDataset(data1)
"""
@check_paddeddataset
def __init__(self, padded_samples):
dataset = _PaddedDataset(padded_samples)
......
......@@ -23,6 +23,7 @@ from functools import wraps
import numpy as np
from mindspore._c_expression import typing
from mindspore.dataset.callback import DSCallback
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
......@@ -31,6 +32,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
from . import datasets
from . import samplers
from . import cache_client
from .. import callback
def check_imagefolderdatasetv2(method):
......@@ -247,6 +249,7 @@ def check_celebadataset(method):
return new_method
def check_save(method):
"""A wrapper that wrap a parameter checker to the save op."""
......@@ -257,7 +260,7 @@ def check_save(method):
nreq_param_int = ['num_files']
nreq_param_str = ['file_name', 'file_type']
validate_dataset_param_value(nreq_param_int, param_dict, int)
if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
raise ValueError("num_files should between {} and {}.".format(1, 1000))
validate_dataset_param_value(nreq_param_str, param_dict, str)
if param_dict.get('file_type') != 'mindrecord':
......@@ -265,6 +268,8 @@ def check_save(method):
return method(self, *args, **kwargs)
return new_method
def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
......@@ -362,6 +367,7 @@ def check_generatordataset(method):
return new_method
def check_random_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
......@@ -545,7 +551,8 @@ def check_map(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache,
callbacks], _ = \
parse_user_args(method, *args, **kwargs)
nreq_param_columns = ['input_columns', 'output_columns']
......@@ -558,9 +565,17 @@ def check_map(method):
if cache is not None:
type_check(cache, (cache_client.DatasetCache,), "cache")
if callbacks is not None:
if isinstance(callbacks, (list, tuple)):
type_check_list(callbacks, (callback.DSCallback,), "callbacks")
else:
type_check(callbacks, (callback.DSCallback,), "callbacks")
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
if param is not None:
check_columns(param, param_name)
if callbacks is not None:
type_check(callbacks, (list, DSCallback), "callbacks")
return method(self, *args, **kwargs)
......
......@@ -15,6 +15,7 @@ SET(DE_UT_SRCS
bounding_box_augment_op_test.cc
arena_test.cc
btree_test.cc
callback_test.cc
center_crop_op_test.cc
channel_swap_test.cc
circular_pool_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <list>
#include "common/common.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/kernels/data/no_op.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::MsLogLevel::INFO;
namespace mindspore {
namespace dataset {
namespace test {
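// Helper: chain the given ops into a linear ExecutionTree, with ops[0] as the leaf and the last op as the root.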
std::shared_ptr<ExecutionTree> BuildTree(std::vector<std::shared_ptr<DatasetOp>> ops) {
std::shared_ptr<ExecutionTree> tree = std::make_shared<ExecutionTree>();
Status rc;
for (size_t i = 0; i < ops.size(); i++) {
rc = tree->AssociateNode(ops[i]);
EXPECT_TRUE(rc.IsOk());
if (i > 0) {
rc = ops[i]->AddChild(ops[i - 1]);
EXPECT_TRUE(rc.IsOk());
}
if (i == ops.size() - 1) {
rc = tree->AssignRoot(ops[i]);
EXPECT_TRUE(rc.IsOk());
}
}
return tree;
}
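// DSCallback test double that records the name, step number and epoch number of every invocation so tests can assert the exact firing sequence.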
class TestCallback : public DSCallback {
public:
TestCallback(int32_t step_size)
: DSCallback(step_size),
begin_(true),
epoch_begin_(true),
step_begin_(true),
end_(true),
epoch_end_(true),
step_end_(true) {
all_names_.reserve(32);
all_step_nums_.reserve(32);
all_ep_nums_.reserve(32);
}
Status DSBegin(const CallbackParam &cb_param) override {
all_names_.push_back("BGN");
all_step_nums_.push_back(cb_param.cur_step_num_);
all_ep_nums_.push_back(cb_param.cur_epoch_num_);
return Status::OK();
}
Status DSEpochBegin(const CallbackParam &cb_param) override {
all_names_.push_back("EPBGN");
all_step_nums_.push_back(cb_param.cur_step_num_);
all_ep_nums_.push_back(cb_param.cur_epoch_num_);
return Status::OK();
}
Status DSNStepBegin(const CallbackParam &cb_param) override {
all_names_.push_back("SPBGN");
all_step_nums_.push_back(cb_param.cur_step_num_);
all_ep_nums_.push_back(cb_param.cur_epoch_num_);
return Status::OK();
}
Status DSEnd(const CallbackParam &cb_param) override {
all_names_.push_back("END");
all_step_nums_.push_back(cb_param.cur_step_num_);
all_ep_nums_.push_back(cb_param.cur_epoch_num_);
return Status::OK();
}
Status DSEpochEnd(const CallbackParam &cb_param) override {
all_names_.push_back("EPEND");
all_step_nums_.push_back(cb_param.cur_step_num_);
all_ep_nums_.push_back(cb_param.cur_epoch_num_);
return Status::OK();
}
Status DSNStepEnd(const CallbackParam &cb_param) override {
all_names_.push_back("SPEND");
all_step_nums_.push_back(cb_param.cur_step_num_);
all_ep_nums_.push_back(cb_param.cur_epoch_num_);
return Status::OK();
}
bool IsBeginNeeded() override { return begin_; }
bool IsEpochBeginNeeded() override { return epoch_begin_; }
bool IsNStepBeginNeeded() override { return step_begin_; }
bool IsEndNeeded() override { return end_; }
bool IsEpochEndNeeded() override { return epoch_end_; }
bool IsNStepEndNeeded() override { return step_end_; }
std::vector<std::string> all_names(size_t len) {
return std::vector<std::string>(all_names_.begin(), all_names_.begin() + len);
}
std::vector<int64_t> all_step_nums(size_t len) {
return std::vector<int64_t>(all_step_nums_.begin(), all_step_nums_.begin() + len);
}
std::vector<int64_t> all_ep_nums(size_t len) {
return std::vector<int64_t>(all_ep_nums_.begin(), all_ep_nums_.begin() + len);
}
// flags for turning each callback method on and off
bool begin_, epoch_begin_, step_begin_, end_, epoch_end_, step_end_;
// names of the callback methods in invocation order: BGN, EPBGN, SPBGN, END, EPEND, SPEND
std::vector<std::string> all_names_;
std::vector<int64_t> all_step_nums_, all_ep_nums_;
};
} // namespace test
} // namespace dataset
} // namespace mindspore
class MindDataTestCallback : public UT::DatasetOpTesting {
public:
void SetUp() override {
DatasetOpTesting::SetUp();
GlobalInit();
}
};
TEST_F(MindDataTestCallback, TestBasicCallback) {
// config callback
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
std::shared_ptr<DSCallback> cb1 = tst_cb;
tst_cb->end_ = false; // don't do the end for now due to a timing issue
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
schema->AddColumn(col);
std::shared_ptr<RandomDataOp> leaf;
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf);
EXPECT_TRUE(rc.IsOk());
// config mapOp
std::shared_ptr<MapOp> map_op;
auto map_b = MapOp::Builder();
rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
EXPECT_TRUE(rc.IsOk());
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(2).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
// start build then launch tree
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorMap tensor_map;
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (!tensor_map.empty()) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
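// 44 rows repeated twice gives 88 steps in the epoch; with step_size = 64 the step callbacks fire at steps 1 and 65, and the epoch ends at step 88.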
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// compare only the first `len` recorded callbacks to make sure no unexpected epoch_end or extra epoch_begin was called
size_t len = 7;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
}
TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
// config callback
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
std::shared_ptr<DSCallback> cb1 = tst_cb;
tst_cb->end_ = false; // don't do the end for now due to a timing issue
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
schema->AddColumn(col);
std::shared_ptr<RandomDataOp> leaf;
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
EXPECT_TRUE(rc.IsOk());
// config mapOp
std::shared_ptr<MapOp> map_op;
auto map_b = MapOp::Builder();
rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
EXPECT_TRUE(rc.IsOk());
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(2).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
// start build then launch tree
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorMap tensor_map;
size_t num_epochs = 2;
for (int ep_num = 0; ep_num < num_epochs; ++ep_num) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (!tensor_map.empty()) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
}
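// 4 rows repeated twice gives 8 steps per epoch over 2 epochs; with step_size = 4 the step callbacks fire at steps 1, 5, 9 and 13, and the epochs end at steps 8 and 16.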
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND",
"EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::vector<int64_t> all_steps = {0, 0, 1, 1, 5, 5, 8, 8, 9, 9, 13, 13, 16};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2};
size_t len = 13;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
}
TEST_F(MindDataTestCallback, TestSelectedCallback) {
// config callback
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
std::shared_ptr<DSCallback> cb1 = tst_cb;
tst_cb->end_ = false;
// turn off the epochs
tst_cb->epoch_begin_ = false;
tst_cb->epoch_end_ = false;
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
schema->AddColumn(col);
std::shared_ptr<RandomDataOp> leaf;
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
EXPECT_TRUE(rc.IsOk());
// config mapOp
std::shared_ptr<MapOp> map_op;
auto map_b = MapOp::Builder();
rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
EXPECT_TRUE(rc.IsOk());
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(2).Build(&repeat_op);
EXPECT_TRUE(rc.IsOk());
// start build then launch tree
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorMap tensor_map;
size_t num_epochs = 2;
for (int ep_num = 0; ep_num < num_epochs; ++ep_num) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (!tensor_map.empty()) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
}
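// epoch_begin_/epoch_end_ are turned off, so only BGN and the step callbacks (steps 1, 5, 9, 13) are expected.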
std::vector<std::string> callback_names = {"BGN", "SPBGN", "SPEND", "SPBGN", "SPEND",
"SPBGN", "SPEND", "SPBGN", "SPEND"};
std::vector<int64_t> all_steps = {0, 1, 1, 5, 5, 9, 9, 13, 13};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 2, 2, 2, 2};
size_t len = 9;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from builtins import range, super
import time
import pytest
from mindspore import context
from mindspore import log as logger
from mindspore.dataset.callback import DSCallback, WaitedDSCallback
from mindspore.train import Model
from mindspore.train.callback import Callback
import mindspore.dataset as ds
import mindspore.nn as nn
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class MyDSCallback(DSCallback):
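"""Records each callback invocation as 'name_epoch_stepInEpoch_step' and collects the cb_ids that observed it."""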
def __init__(self, step_size=1, events=None, cb_id=0):
super().__init__(step_size)
self.events = events
self.cb_id = cb_id
def append(self, event_name, ds_run_context):
event = [event_name, ds_run_context.cur_epoch_num,
ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
event = '_'.join([str(e) for e in event])
index = -1
for i, e in enumerate(self.events):
if e[0] == event:
index = i
break
if index != -1:
self.events[index][1].append(self.cb_id)
else:
self.events.append((event, [self.cb_id]))
def ds_begin(self, ds_run_context):
self.append("begin", ds_run_context)
def ds_end(self, ds_run_context):
self.append("end", ds_run_context)
def ds_epoch_begin(self, ds_run_context):
self.append("epoch_begin", ds_run_context)
def ds_epoch_end(self, ds_run_context):
self.append("epoch_end", ds_run_context)
def ds_step_begin(self, ds_run_context):
self.append("step_begin", ds_run_context)
def ds_step_end(self, ds_run_context):
self.append("step_end", ds_run_context)
def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
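"""Build the expected (event, cb_ids) list: one begin event, then per epoch an epoch_begin, step_begin/step_end every step_size steps, and an epoch_end."""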
events = []
cb_id = list(range(map_num))
def append(name, e, s):
event = [name, e + 1, s + 1, e * step_num * repeat + s + 1]
event = '_'.join([str(ev) for ev in event])
events.append((event, cb_id))
events.append(("begin_0_0_0", cb_id))
for e in range(epoch_num):
append("epoch_begin", e, -1)
for s in range(step_num * repeat):
if s % step_size == 0:
append("step_begin", e, s)
append("step_end", e, s)
append("epoch_end", e, step_num * repeat - 1)
return events
def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
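"""Run a single-map pipeline with one callback for the given epochs/steps and compare the recorded events against generate_expected."""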
events = []
arr = list(range(1, steps + 1))
data = ds.NumpySlicesDataset(arr, shuffle=False)
my_cb = MyDSCallback(step_size=step_size, events=events)
data = data.map(operations=(lambda x: x), callbacks=my_cb)
if repeat != 1:
data = data.repeat(repeat)
itr = data.create_tuple_iterator(num_epochs=epochs)
for _ in range(epochs):
for _ in itr:
pass
expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
assert expected_events == events
def build_test_case_2cbs(epochs, steps):
events1 = []
events2 = []
my_cb1 = MyDSCallback(events=events1)
my_cb2 = MyDSCallback(events=events2)
arr = list(range(1, steps + 1))
data = ds.NumpySlicesDataset(arr, shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])
itr = data.create_tuple_iterator(num_epochs=epochs)
for _ in range(epochs):
for _ in itr:
pass
expected_events = generate_expected(epochs, steps)
assert expected_events == events1
assert expected_events == events2
def build_test_case_2maps(epochs, steps):
events = []
my_cb1 = MyDSCallback(events=events, cb_id=0)
my_cb2 = MyDSCallback(events=events, cb_id=1)
arr = list(range(1, steps + 1))
data = ds.NumpySlicesDataset(arr, shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=my_cb1)
data = data.map(operations=(lambda x: x), callbacks=my_cb2)
itr = data.create_tuple_iterator(num_epochs=epochs)
for _ in range(epochs):
for _ in itr:
pass
expected_events = generate_expected(epochs, steps, map_num=2)
assert expected_events[1:] == events[1:]
for event in events:
assert len(event) == 2
event, cb_ids = event
if event != "begin_0_0_0":
assert cb_ids[0] == 0
assert cb_ids[1] == 1
def test_callbacks_all_methods():
logger.info("test_callbacks_all_methods")
build_test_case_1cb(1, 1)
build_test_case_1cb(1, 2)
build_test_case_1cb(1, 3)
build_test_case_1cb(1, 4)
build_test_case_1cb(2, 1)
build_test_case_1cb(2, 2)
build_test_case_1cb(2, 3)
build_test_case_1cb(2, 4)
build_test_case_1cb(3, 1)
build_test_case_1cb(3, 2)
build_test_case_1cb(3, 3)
build_test_case_1cb(3, 4)
def test_callbacks_var_step_size():
logger.info("test_callbacks_var_step_size")
build_test_case_1cb(1, 2, 2)
build_test_case_1cb(1, 3, 2)
build_test_case_1cb(1, 4, 2)
build_test_case_1cb(2, 2, 2)
build_test_case_1cb(2, 3, 2)
build_test_case_1cb(2, 4, 2)
build_test_case_1cb(3, 2, 2)
build_test_case_1cb(3, 3, 2)
build_test_case_1cb(3, 4, 2)
def test_callbacks_all_2cbs():
logger.info("test_callbacks_all_2cbs")
build_test_case_2cbs(4, 1)
build_test_case_2cbs(4, 2)
build_test_case_2cbs(4, 3)
build_test_case_2cbs(4, 4)
def test_callbacks_2maps():
logger.info("test_callbacks_2maps")
build_test_case_2maps(5, 10)
build_test_case_2maps(6, 9)
class MyWaitedCallback(WaitedDSCallback):
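"""Records the dataset-side sync points ('ds_epoch_begin_...', 'ds_step_begin_...') into a shared event list."""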
def __init__(self, events, step_size=1):
super().__init__(step_size)
self.events = events
def sync_epoch_begin(self, train_run_context, ds_run_context):
event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
self.events.append(event)
def sync_step_begin(self, train_run_context, ds_run_context):
event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
self.events.append(event)
class MyMSCallback(Callback):
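"""Training-side callback that records 'ms_step_end_...' and 'ms_epoch_end_...' events into the same list."""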
def __init__(self, events):
self.events = events
def epoch_end(self, run_context):
cb_params = run_context.original_args()
event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
self.events.append(event)
def step_end(self, run_context):
cb_params = run_context.original_args()
event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
self.events.append(event)
class Net(nn.Cell):
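"""Identity network that returns its first input unchanged; only used to drive Model.train()."""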
def construct(self, x, y):
return x
def test_train_non_sink():
logger.info("test_train_non_sink")
events = []
my_cb1 = MyWaitedCallback(events, 1)
my_cb2 = MyMSCallback(events)
arr = [1, 2, 3, 4]
data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=my_cb1)
net = Net()
model = Model(net)
model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
'ms_step_end_2_8', 'ms_epoch_end_2_8']
assert events == expected_synced_events
def test_train_batch_size2():
logger.info("test_train_batch_size2")
events = []
my_cb1 = MyWaitedCallback(events, 2)
my_cb2 = MyMSCallback(events)
arr = [1, 2, 3, 4]
data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=my_cb1)
data = data.batch(2)
net = Net()
model = Model(net)
model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
'ms_step_end_1_2',
'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
'ms_step_end_2_4', 'ms_epoch_end_2_4']
assert events == expected_synced_events
def test_callbacks_validations():
logger.info("test_callbacks_validations")
with pytest.raises(Exception) as err:
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data.map(operations=(lambda x: x), callbacks=0)
assert "Argument callbacks with value 0 is not " in str(err.value)
with pytest.raises(Exception) as err:
my_cb1 = MyDSCallback()
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
assert "Argument callbacks[1] with value 0 is not " in str(err.value)
with pytest.raises(Exception) as err:
class BadCB(DSCallback):
pass
my_cb = BadCB()
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=my_cb)
for _ in data:
pass
assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)
def test_callback_sink_simulation():
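"""Simulate dataset-sink training by manually firing the train-side step_end/epoch_end so the waited dataset callback can sync."""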
logger.info("test_callback_sink_simulation")
events = []
epochs = 2
my_cb = MyWaitedCallback(events, 1)
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=my_cb)
data = data.to_device()
data.send(num_epochs=epochs)
for e in range(epochs):
for s in range(4):
time.sleep(0.5)
events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
my_cb.step_end(run_context=0)
events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
my_cb.epoch_end(run_context=0)
expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
'ms_step_end_2_8', 'ms_epoch_end_2_8']
assert events == expected_synced_events
def test_callbacks_repeat():
logger.info("test_callbacks_repeat")
build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)
if __name__ == '__main__':
test_callbacks_all_methods()
test_callbacks_all_2cbs()
test_callbacks_2maps()
test_callbacks_validations()
test_callbacks_var_step_size()
test_train_batch_size2()
test_callback_sink_simulation()
test_callbacks_repeat()