diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 678c20ba102f9aa08bcd67ed4d7ad13613ccb7f0..ebb3b334fdace71981494d2a39b7e6c53c82e96e 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -58,6 +58,7 @@ add_subdirectory(kernels) add_subdirectory(engine) add_subdirectory(api) add_subdirectory(text) +add_subdirectory(callback) ###################################################################### add_dependencies(utils core) add_dependencies(kernels-image core) @@ -74,6 +75,7 @@ add_dependencies(engine-cache-server core) add_dependencies(engine-perf core) add_dependencies(engine-gnn core) add_dependencies(engine core) +add_dependencies(callback core) add_dependencies(text core) add_dependencies(text-kernels core) add_dependencies(cpp-API core) @@ -87,6 +89,7 @@ endif () ################### Create _c_dataengine Library ###################### set(submodules $ + $ $ $ $ @@ -135,14 +138,14 @@ endif() target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") if (ENABLE_PYTHON) - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) else() target_link_libraries(_c_dataengine PRIVATE mindspore::protobuf ${SECUREC_LIBRARY}) endif() else() set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n) if (ENABLE_PYTHON) - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) else() target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY}) endif() diff --git a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt index 556b91ece03611bebaf7c7898870646b601f6406..d3fd8bc4343c59b9df6da83503171e9e37626a72 100644 --- a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt @@ -7,6 +7,7 @@ if (ENABLE_PYTHON) python/bindings.cc python/bindings/dataset/engine/cache/bindings.cc python/bindings/dataset/core/bindings.cc + python/bindings/dataset/callback/bindings.cc python/bindings/dataset/kernels/data/bindings.cc python/bindings/dataset/kernels/bindings.cc python/bindings/dataset/engine/datasetops/bindings.cc diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/callback/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/callback/bindings.cc new file mode 100644 index 0000000000000000000000000000000000000000..2798d8b23da0590fd7a101a8d7a855bab905d531 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/callback/bindings.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" +#include "pybind11/stl_bind.h" + +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/callback/py_ds_callback.h" +#include "minddata/dataset/callback/ds_callback.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) { + (void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback") + .def(py::init<int32_t>()) + .def("set_begin", &PyDSCallback::setBegin) + .def("set_end", &PyDSCallback::setEnd) + .def("set_epoch_begin", &PyDSCallback::setEpochBegin) + .def("set_epoch_end", &PyDSCallback::setEpochEnd) + .def("set_step_begin", &PyDSCallback::setStepBegin) + .def("set_step_end", &PyDSCallback::setStepEnd); + })); + +PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) { + (void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam") + .def(py::init<int64_t, int64_t, int64_t>()) + .def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_) + .def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_) + .def_readonly("cur_step_num", &CallbackParam::cur_step_num_); + })); +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc index c74ffa73938d6d4bb750000cc0a45ca1c9c8df9c..4b10aa5057ddfd1829785a8581a3bb08a03d8c5e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -20,6 +20,7 @@ #include #include "utils/ms_utils.h" +#include "minddata/dataset/callback/py_ds_callback.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/dataset_iterator.h" @@ -738,8 +739,13 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); } else if (key == "cache") { cache_client = value.cast<std::shared_ptr<CacheClient>>(); + } else if (key == "callbacks") { + std::vector<std::shared_ptr<DSCallback>> callbacks; + std::transform(value.begin(), value.end(), std::back_inserter(callbacks), + [](py::handle cb) { return cb.cast<std::shared_ptr<PyDSCallback>>(); }); + (void)map_builder.AddCallbacks(callbacks); } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + RETURN_STATUS_UNEXPECTED("Error in parsing MapOp: Unhandled key: " + key); } } } diff --git a/mindspore/ccsrc/minddata/dataset/callback/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/callback/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8bb19f6dd6991fa335b758a2ed856eeb62e74e4f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + + +if (ENABLE_PYTHON) + add_library(callback OBJECT + callback_manager.cc + py_ds_callback.cc + ) +else () + add_library(callback OBJECT + callback_manager.cc + ) +endif () \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc
b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..948ae8d45206abc799c17b31f140e52e9897a274 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/callback/callback_manager.h" +#include "minddata/dataset/callback/ds_callback.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +namespace mindspore { +namespace dataset { + +void CallbackManager::AddCallbacks(std::vector> callbacks) { + callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end()); +} + +Status CallbackManager::Init(std::shared_ptr op) { + RETURN_UNEXPECTED_IF_NULL(op); + op_ = op; + // turn the flag on if callback is set + enabled_ = !callbacks_.empty(); + + // error check for each of the callbacks + for (auto &cb : callbacks_) { + CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0."); + } + + return Status::OK(); +} + +Status CallbackManager::Begin(const CallbackParam &cb_param) { + RETURN_OK_IF_TRUE(!enabled_); + std::vector callback_inds; + // go through all callback functions to see if each function is needed + for (size_t ind = 0; ind < callbacks_.size(); ind++) { + if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind); + } + // return Status::OK() if no begin is needed + RETURN_OK_IF_TRUE(callback_inds.empty()); + + RETURN_IF_NOT_OK(op_->PauseFromMaster()); + + // Now do the actual callback + for (size_t ind : callback_inds) { + RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param)); + } + return Status::OK(); +} + +Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { + RETURN_OK_IF_TRUE(!enabled_); + std::vector callback_inds; + // go through all callback functions to see if each function is needed + for (size_t ind = 0; ind < callbacks_.size(); ind++) { + if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind); + } + // return Status::OK() if no epoch_begin is needed + RETURN_OK_IF_TRUE(callback_inds.empty()); + + RETURN_IF_NOT_OK(op_->PauseFromMaster()); + + // Now do the actual callback + for (size_t ind : callback_inds) { + RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param)); + } + return Status::OK(); +} + +Status CallbackManager::StepBegin(const CallbackParam &cb_param) { + RETURN_OK_IF_TRUE(!enabled_); + std::vector callback_inds; + // go through all callback functions to see if each function is needed + for (size_t ind = 0; ind < callbacks_.size(); ind++) { + if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0) + callback_inds.push_back(ind); + } + // return Status::OK() if no step_begin is needed + RETURN_OK_IF_TRUE(callback_inds.empty()); + + RETURN_IF_NOT_OK(op_->PauseFromMaster()); + + // Now do the actual callback + for 
(size_t ind : callback_inds) { + RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param)); + } + return Status::OK(); +} + +Status CallbackManager::End(const CallbackParam &cb_param) { + RETURN_OK_IF_TRUE(!enabled_); + std::vector callback_inds; + // go through all callback functions to see if each function is needed + for (size_t ind = 0; ind < callbacks_.size(); ind++) { + if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind); + } + // return Status::OK() if no end is needed + RETURN_OK_IF_TRUE(callback_inds.empty()); + + RETURN_IF_NOT_OK(op_->PauseFromMaster()); + + // Now do the actual callback + for (size_t ind : callback_inds) { + RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param)); + } + return Status::OK(); +} + +Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { + RETURN_OK_IF_TRUE(!enabled_); + std::vector callback_inds; + // go through all callback functions to see if each function is needed + for (size_t ind = 0; ind < callbacks_.size(); ind++) { + if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind); + } + // return Status::OK() if no epoch_end is needed + RETURN_OK_IF_TRUE(callback_inds.empty()); + + RETURN_IF_NOT_OK(op_->PauseFromMaster()); + + // Now do the actual callback + for (size_t ind : callback_inds) { + RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param)); + } + return Status::OK(); +} + +Status CallbackManager::StepEnd(const CallbackParam &cb_param) { + RETURN_OK_IF_TRUE(!enabled_); + std::vector callback_inds; + // go through all callback functions to see if each function is needed + for (size_t ind = 0; ind < callbacks_.size(); ind++) { + if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0) + callback_inds.push_back(ind); + } + // return Status::OK() if no step_end is needed + RETURN_OK_IF_TRUE(callback_inds.empty()); + + RETURN_IF_NOT_OK(op_->PauseFromMaster()); + + // Now do the actual callback + for (size_t ind : callback_inds) { + RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param)); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..73ffda0fd47f30aa8d66663fb2e2a89f4dd7e817 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H + +#include <memory> +#include <vector> + +#include "minddata/dataset/callback/ds_callback.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +// forward declare to avoid cyclic include of dataset_op.h +class DatasetOp; + +/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this. +class CallbackManager { + public: + /// CallbackManager default constructor. Init needs to be called before using the created instance. + CallbackManager() : enabled_(false) {} + + /// \brief Add a list of callbacks to this CallbackManager + /// \param [in] callbacks list of callbacks to perform + void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks); + + /// \brief DatasetOp needs to call Init if it wishes to use callbacks, Init will set enabled_ to true + /// \param[in] op, this pointer is used by the CallbackManager to pause worker threads + /// \return Status + Status Init(std::shared_ptr<DatasetOp> op); + + /// \brief callback function called at the start of the first row + /// \return Status + Status Begin(const CallbackParam &); + + /// \brief callback function called at the start of each epoch + /// \return Status + Status EpochBegin(const CallbackParam &); + + /// \brief callback function called at the start of each row + /// \return Status + Status StepBegin(const CallbackParam &); + + /// \brief callback function called after the last row is processed + /// \return Status + Status End(const CallbackParam &); + + /// \brief callback function called at the end of each epoch + /// \return Status + Status EpochEnd(const CallbackParam &); + + /// \brief callback function called at the end of each row + /// \return Status + Status StepEnd(const CallbackParam &); + + private: + bool enabled_; // flag to enable callback, if false, all functions would return immediately + std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager + std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_param.h b/mindspore/ccsrc/minddata/dataset/callback/callback_param.h new file mode 100644 index 0000000000000000000000000000000000000000..1dfe492002444a6cb5f0326e8a30eee39da8dfc9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_param.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H + +#include <cstdint> + +namespace mindspore { +namespace dataset { + +/// CallbackParam is the object a DatasetOp uses to pass run-time information to user-defined functions.
+/// This is a prototype for now, more fields will be added +class CallbackParam { + public: + CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num) + : cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {} + + // these are constant public fields for easy access and consistency with python cb_param + // the names and orders are consistent with batchInfo + const int64_t cur_epoch_num_; // current epoch + const int64_t cur_epoch_step_num_; // step number of the current epoch + const int64_t cur_step_num_; // step number since the first row +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H diff --git a/mindspore/ccsrc/minddata/dataset/callback/ds_callback.h b/mindspore/ccsrc/minddata/dataset/callback/ds_callback.h new file mode 100644 index 0000000000000000000000000000000000000000..95e36c2b81c5b28a2012ed8abf917e8c8306573f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/ds_callback.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H + +#include +#include +#include + +#include "minddata/dataset/callback/callback_param.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class DSCallback { + public: + /// \brief constructor of DSCallback, this is the base class for all front end specific callbacks + /// \param step_size number of steps to call DSNStepBegin() + explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {} + + /// \brief actual callback function for begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + virtual Status DSBegin(const CallbackParam &cb_param) = 0; + + /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0; + + /// \brief actual callback function for step_begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0; + + /// \brief actual callback function for end, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + virtual Status DSEnd(const CallbackParam &cb_param) = 0; + + /// \brief actual callback function epoch_end begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status 
+ virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0; + + /// \brief actual callback function for step_end, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0; + + /// \brief predicate function, whether begin callback is needed + /// \return bool + virtual bool IsBeginNeeded() = 0; + + /// \brief predicate function, whether epoch_begin callback is needed + /// \return bool + virtual bool IsEpochBeginNeeded() = 0; + + /// \brief predicate function, whether step_begin callback is needed + /// \return bool + virtual bool IsNStepBeginNeeded() = 0; + + /// \brief predicate function, whether end callback is needed + /// \return bool + virtual bool IsEndNeeded() = 0; + + /// \brief predicate function, whether epoch_end callback is needed + /// \return bool + virtual bool IsEpochEndNeeded() = 0; + + /// \brief predicate function, whether step_end callback is needed + /// \return bool + virtual bool IsNStepEndNeeded() = 0; + + /// \brief getter + /// \return step_size + int32_t step_size() const { return step_size_; } + + protected: + int32_t step_size_; // step begin/end will be called every step_size_ +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H diff --git a/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc b/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc new file mode 100644 index 0000000000000000000000000000000000000000..56416f72f98514385850bdd7713e6457179913d2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/callback/callback_manager.h" +#include "minddata/dataset/callback/py_ds_callback.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status PyDSCallback::DSBegin(const CallbackParam &cb_param) { + return PyDSCallback::ExecutePyfunc(begin_func_, cb_param); +} +Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) { + return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param); +} +Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) { + return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param); +} +Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); } + +Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) { + return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param); +} +Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) { + return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param); +} + +bool PyDSCallback::IsBeginNeeded() { return begin_needed_; } +bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; } +bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; } +bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; } +bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; } +bool PyDSCallback::IsEndNeeded() { return end_needed_; } + +Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) { + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + f(cb_param); + } + return Status::OK(); +} +void PyDSCallback::setBegin(py::function f) { + begin_func_ = f; + begin_needed_ = true; +} +void PyDSCallback::setEnd(py::function f) { + end_func_ = f; + end_needed_ = true; +} +void PyDSCallback::setEpochBegin(py::function f) { + epoch_begin_func_ = f; + epoch_begin_needed_ = true; +} +void PyDSCallback::setEpochEnd(py::function f) { + epoch_end_func_ = f; + epoch_end_needed_ = true; +} +void PyDSCallback::setStepBegin(py::function f) { + step_begin_func_ = f; + step_begin_needed_ = true; +} +void PyDSCallback::setStepEnd(py::function f) { + step_end_func_ = f; + step_end_needed_ = true; +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.h b/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.h new file mode 100644 index 0000000000000000000000000000000000000000..644930ddf170c41dee62f8cfbafcc7e50ea00089 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H +#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H + +#include +#include +#include + +#include "minddata/dataset/callback/ds_callback.h" +#include "minddata/dataset/util/status.h" +#include "pybind11/pybind11.h" + +namespace mindspore { +namespace dataset { + +namespace py = pybind11; + +class PyDSCallback : public DSCallback { + public: + /// \brief constructor for PyDSCallback. This callback is for python front end + explicit PyDSCallback(int32_t step_size = 1) + : DSCallback(step_size), + begin_needed_(false), + epoch_begin_needed_(false), + step_begin_needed_(false), + end_needed_(false), + epoch_end_needed_(false), + step_end_needed_(false) {} + + void setBegin(py::function f); + void setEnd(py::function f); + void setEpochBegin(py::function f); + void setEpochEnd(py::function f); + void setStepBegin(py::function f); + void setStepEnd(py::function f); + + /// \brief actual callback function for begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + Status DSBegin(const CallbackParam &cb_param) override; + + /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + Status DSEpochBegin(const CallbackParam &cb_param) override; + + /// \brief actual callback function for step_begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + Status DSNStepBegin(const CallbackParam &cb_param) override; + + /// \brief actual callback function for end, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + Status DSEnd(const CallbackParam &cb_param) override; + + /// \brief actual callback function epoch_end begin, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + Status DSEpochEnd(const CallbackParam &cb_param) override; + + /// \brief actual callback function for step_end, needs to be overridden in the derived class + /// \param cb_param, callback parameter passed in from DatasetOp when calling the callback + /// \return Status + Status DSNStepEnd(const CallbackParam &cb_param) override; + + /// \brief predicate function, whether begin callback is needed + /// \return bool + bool IsBeginNeeded() override; + + /// \brief predicate function, whether epoch_begin callback is needed + /// \return bool + bool IsEpochBeginNeeded() override; + + /// \brief predicate function, whether step_begin callback is needed + /// \return bool + bool IsNStepBeginNeeded() override; + + /// \brief predicate function, whether end callback is needed + /// \return bool + bool IsEndNeeded() override; + + /// \brief predicate function, whether epoch_end callback is needed + /// \return bool + bool IsEpochEndNeeded() override; + + /// \brief predicate function, whether step_end callback is needed + /// \return bool + bool IsNStepEndNeeded() override; + + /// \brief helper function to acquire GIL then execute a pyfunc + /// \param f the python function + /// \param cb_param + /// \return Status + static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param); + + 
private: + py::function begin_func_; + py::function epoch_begin_func_; + py::function step_begin_func_; + py::function end_func_; + py::function epoch_end_func_; + py::function step_end_func_; + + bool begin_needed_; + bool epoch_begin_needed_; + bool step_begin_needed_; + bool end_needed_; + bool epoch_end_needed_; + bool step_end_needed_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 3c83582c9f44308030a955327569cbe3a282799e..1e06049c87cac0815cd6c9972e99b8742db04d5b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -21,6 +21,8 @@ #include #include #include + +#include "minddata/dataset/callback/callback_manager.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/util/status.h" @@ -358,6 +360,14 @@ class DatasetOp : public std::enable_shared_from_this { /// \return boolean returns true if it's last iteration bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; } + /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp + /// The expected behavior is this, when this function is invoked, this function will block until all the workers + /// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. + /// They would automatically wait on the QueueList when they are done. Hence, for now, a Unpause() function is not + /// needed. Only parallelOp needs to override this function. + /// \return Status + virtual Status PauseFromMaster() { return Status::OK(); } + protected: /// \brief Removes a parent operator from this operator /// \notes External callers do not have access to this function @@ -394,6 +404,7 @@ class DatasetOp : public std::enable_shared_from_this { std::unique_ptr out_connector_; // Output Connector std::unordered_map column_name_id_map_; // Mapping between col index and col name std::mutex column_name_map_mutex_; // For protecting shared access to the column map + CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp private: /// Sets the operator id. 
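[Editor's note, not part of the patch: the dataset_op.h hunk above only states the PauseFromMaster() contract in prose. The following is a minimal, hedged sketch of the handshake it implies — the master pushes one pause token per worker, each worker bumps an atomic counter, and the last worker wakes the master. PausePoint and its methods are hypothetical names for illustration; the actual change below uses the existing QueueList/WaitPost utilities and an empty DataBuffer with id -1 as the pause token.]

// Illustrative sketch of the pause handshake; assumes standard C++ primitives only.
#include <atomic>
#include <condition_variable>
#include <mutex>

class PausePoint {
 public:
  explicit PausePoint(int num_workers) : num_workers_(num_workers), paused_(0), done_(false) {}

  // Worker side: called when a worker dequeues the special pause token.
  void WorkerPaused() {
    if (++paused_ == num_workers_) {
      std::lock_guard<std::mutex> lock(mu_);
      done_ = true;
      cv_.notify_one();  // the last worker to pause wakes the master
    }
  }

  // Master side: called after pushing one pause token per worker.
  void WaitForAllWorkers() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return done_; });
    done_ = false;  // reset so the next pause cycle can reuse the object
    paused_ = 0;
  }

 private:
  const int num_workers_;
  std::atomic<int> paused_;
  bool done_;
  std::mutex mu_;
  std::condition_variable cv_;
};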
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc index dff595e11e28a31e007c0085b4f344ebe2548ebd..e7061b8ccc07a786a1b59f27eae51de13c571453 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -15,25 +15,23 @@ */ #include #include -#include #include #include #include #include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/callback/callback_param.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/engine/data_buffer.h" -#include "minddata/dataset/engine/db_connector.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "minddata/dataset/engine/opt/pass.h" -#include "minddata/dataset/engine/datasetops/map_op/map_op.h" #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/kernels/tensor_op.h" -#include "utils/log_adapter.h" #include "minddata/dataset/util/task_manager.h" +#include "utils/log_adapter.h" namespace mindspore { namespace dataset { @@ -58,6 +56,7 @@ Status MapOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(sanityCheck()); *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_); + (*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_)); return Status::OK(); } @@ -164,7 +163,10 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr *worker_job) Status MapOp::operator()() { // Create and register the local queues. 
local_queues_.Init(num_workers_, oc_queue_size_); + // init callback + RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); Status rc = local_queues_.Register(tree_->AllTasks()); + RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks())); if (rc.IsError()) { TaskManager::FindMe()->Post(); return rc; @@ -175,28 +177,51 @@ Status MapOp::operator()() { // Synchronize with TaskManager TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(rc); + // num_buffers received, including eoe, num_epoch, num_step of current epoch + int64_t num_buf = 0, ep_step = 0, total_step = 0; + RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); - int64_t que_id = 0; std::unique_ptr buff; - bool is_eof = false; - // Drain output connector of the previous op, generate jobs for worker threads, and distribute them via local queues - // Stop when all worker threads are finished (received EOF) - while (!is_eof) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); - is_eof = buff->eof(); - // Create an empty map worker job to be populated by a databuffer and map jobs - std::unique_ptr worker_job = std::make_unique(); - worker_job->databuffer = std::move(buff); + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); + while (!buff->eof()) { + if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) { + RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + } + while (!buff->eoe()) { + ep_step++; + total_step++; + // Create an empty map worker job to be populated by a databuffer and map jobs + RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + std::unique_ptr worker_job = std::make_unique(std::move(buff)); + + // Populate map worker job for a worker to execute + RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job)); - // Populate map worker job for a worker to execute - RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job)); + // Push map worker job to the corresponding worker's queue + RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); + RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); - // Push map worker job to the corresponding worker's queue - RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(worker_job))); - que_id = (que_id + 1) % num_workers_; - } + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); + } + // send the eoe buffer to worker + + // reset epoch_step when a new epoch is about to start + if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { + RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + ep_step = 0; + } + std::unique_ptr worker_job = std::make_unique(std::move(buff)); + RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); + UpdateRepeatAndEpochCounter(); + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); + } + // the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback + // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step))); + // handle eof logic + std::unique_ptr worker_job = std::make_unique(std::move(buff)); + RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); return Status::OK(); } @@ -213,25 +238,19 @@ Status MapOp::WorkerEntry(int32_t worker_id) { // Fetch next data buffer and map job list 
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); - // Sanity check the databuffer. - // Special case: if there's more threads than buffers, some threads simply get the final control - // messages (eoe/eof), and so they will not perform the check. - if (!in_buffer->eoe() && !in_buffer->eof()) { - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - if (num_rows == 0 || num_cols == 0) { - RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); - } - } - // Now that init work is done, drop into the main fetching loop. // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself // rather than use the base-class defaults. while (true) { - // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work - // with Performance Mode design. - if (in_buffer->eoe()) { - UpdateRepeatAndEpochCounter(); + // handle the pause logic. Pause is triggered when a buffer with id -1, no special flag and no rows is received + if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) { + // when a worker receives the signal from the master thread, it increments an atomic int + // the last worker to increment the counter wakes up the master thread + if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set(); + // this will block the worker until the master thread gives it new work + RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); + continue; + } else if (in_buffer->eoe()) { // Calling base class EoeReceived to forward eoe buffer. RETURN_IF_NOT_OK(EoeReceived(worker_id)); // Fetch next data buffer and map job list @@ -243,6 +262,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) { break; } + CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer."); std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>()); // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list)); @@ -281,9 +301,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl std::vector<TensorRow> result_table; // Executing the list of jobs for (size_t i = 0; i < job_list.size(); i++) { - // Executre MapJob. + // Execute MapJob. RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table)); - // Assign the pocessed data as an input for the next job processing, except for the last TensorOp in the list. + // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list. if (i + 1 < job_list.size()) { job_input_table = std::move(result_table); } @@ -428,5 +448,20 @@ Status MapOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base<MapOp>(), modified); } + +Status MapOp::PauseFromMaster() { + // reset num_paused workers to 0 + num_workers_paused_ = 0; + for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) { + // a special buffer (id=-1, empty, none flag) is used to signal that worker needs to pause.
+ RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add( + std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone)))); + } + // wait until all workers are done processing their work in local_queue_ + RETURN_IF_NOT_OK(master_pause_wp_.Wait()); + // clear the WaitPost for the next Wait() + master_pause_wp_.Clear(); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h index 77ee94d86d5e70716ea71f973f5d1a9aa0c3fefb..a66ee55fa9f2aee2a19804930b5047a6735a6bfa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h @@ -16,15 +16,19 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_ +#include #include #include #include #include #include + +#include "minddata/dataset/callback/ds_callback.h" +#include "minddata/dataset/engine/datasetops/map_op/map_job.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/util/queue.h" -#include "minddata/dataset/engine/datasetops/map_op/map_job.h" +#include "minddata/dataset/util/wait_post.h" namespace mindspore { namespace dataset { @@ -108,6 +112,13 @@ class MapOp : public ParallelOp { return *this; } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &AddCallbacks(const std::vector<std::shared_ptr<DSCallback>> &callbacks) { + builder_callbacks_.insert(builder_callbacks_.end(), callbacks.begin(), callbacks.end()); + return *this; + } + // The builder "build" method creates the final object. // @param ptr The shared_ptr to the new MapOp object // @return Status @@ -116,6 +127,7 @@ class MapOp : public ParallelOp { private: std::vector<std::string> build_in_col_names_; std::vector<std::string> build_out_col_names_; + std::vector<std::shared_ptr<DSCallback>> builder_callbacks_; std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_; int32_t build_num_workers_; int32_t build_op_connector_size_; @@ -186,6 +198,7 @@ class MapOp : public ParallelOp { // A unit of job for map worker thread. // MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob. struct MapWorkerJob { + explicit MapWorkerJob(std::unique_ptr<DataBuffer> db) : databuffer(std::move(db)) {} std::vector<std::shared_ptr<MapJob>> jobs; std::unique_ptr<DataBuffer> databuffer; }; @@ -215,6 +228,12 @@ class MapOp : public ParallelOp { // Indices of the columns to process. std::vector<size_t> to_process_indices_; + // wait post used to perform the pausing logic in MapOp + WaitPost master_pause_wp_; + + // count number of workers that have signaled master + std::atomic_int num_workers_paused_; + // Private function for worker/thread to loop continuously. It comprises the main // logic of MapOp: getting the data from previous Op, validating user specified column names, // applying a list of TensorOps to each of the data, process the results and then @@ -247,6 +266,13 @@ class MapOp : public ParallelOp { // Private function for initializing private variables such as in_columns_, out_columns_. // @return - Status Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map); + + // This function should only be called from master thread. It intends to suspend the operation of all workers and + // have them wait on the QueueList. Master thread would send a token to each worker then wait on a WaitPost.
+ // Workers upon receiving the suspension token from master thread, increment an atomic count, the last worker + // who does the increment wakes up the master. + // @return - Status + Status PauseFromMaster() override; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.h b/mindspore/ccsrc/minddata/dataset/util/semaphore.h index 88935dc6f7e853436384f43d212e6ce4fbbbb918..e54516291b3019a8953e3dd9a7de86ca424dadfc 100644 --- a/mindspore/ccsrc/minddata/dataset/util/semaphore.h +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.h @@ -34,7 +34,7 @@ class Semaphore { /// \brief Decrement the internal counter. Will be blocked if the value is 0. /// \return Error code. Can get interrupt. Status P(); - /// \brief Increment the internal counter. Wakeup on of the watiers if any. + /// \brief Increment the internal counter. Wakeup on of the waiters if any. void V(); /// \brief Peek the internal value /// \return The internal value diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h index b919b4dc4e05c39dfd0ebc435840594e714f7392..b5500c101390880319c7884b31b7a5ef295ea62d 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.h +++ b/mindspore/ccsrc/minddata/dataset/util/status.h @@ -59,6 +59,13 @@ namespace dataset { } \ } while (false) +#define RETURN_OK_IF_TRUE(_condition) \ + do { \ + if (_condition) { \ + return Status::OK(); \ + } \ + } while (false) + enum class StatusCode : char { kOK = 0, kOutOfMemory = 1, diff --git a/mindspore/dataset/callback/__init__.py b/mindspore/dataset/callback/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3b67d57f32d45c5c0800f41ce4421b32e5239e --- /dev/null +++ b/mindspore/dataset/callback/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""init file for python callback""" +from .ds_callback import DSCallback, WaitedDSCallback + +__all__ = ["DSCallback", "WaitedDSCallback"] diff --git a/mindspore/dataset/callback/ds_callback.py b/mindspore/dataset/callback/ds_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b1b0140fdd1225c451203ef6c24ffc762fb91a --- /dev/null +++ b/mindspore/dataset/callback/ds_callback.py @@ -0,0 +1,232 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +""" +Python callback class +""" +import threading +from mindspore._c_dataengine import PyDSCallback +from mindspore.train.callback import Callback +from .validators import check_callback + + +class DSCallback: + """ + Abstract base class used to build a dataset callback class. + + Args: + step_size (int, optional): The number of steps before the step_begin and step_end are called (Default=1). + + Examples: + >>> class PrintInfo(DSCallback): + >>> def ds_epoch_end(self, ds_run_context): + >>> print(ds_run_context.cur_epoch_num) + >>> print(ds_run_context.cur_step_num) + >>> + >>> data = data.map(operations=op, callbacks=PrintInfo()) + """ + + @check_callback + def __init__(self, step_size=1): + self.step_size = step_size + + def ds_begin(self, ds_run_context): + """ + Called before the data pipeline is started. + + Args: + ds_run_context (RunContext): Include some information of the pipeline. + """ + + def ds_epoch_begin(self, ds_run_context): + """ + Called before a new epoch is started. + + Args: + ds_run_context (RunContext): Include some information of the pipeline. + """ + + def ds_epoch_end(self, ds_run_context): + """ + Called after an epoch is finished. + + Args: + ds_run_context (RunContext): Include some information of the pipeline. + """ + + def ds_step_begin(self, ds_run_context): + """ + Called before n steps are started. + + Args: + ds_run_context (RunContext): Include some information of the pipeline. + """ + + def ds_step_end(self, ds_run_context): + """ + Called after n steps are finished. + + Args: + ds_run_context (RunContext): Include some information of the pipeline. + """ + + def create_runtime_obj(self): + """ + Creates a runtime (C++) object from the callback methods defined by the user. + + Returns: _c_dataengine.PyDSCallback + """ + c_cb = PyDSCallback(self.step_size) + at_least_one = False + + if self.__class__.ds_begin != DSCallback.ds_begin: + c_cb.set_begin(self.ds_begin) + at_least_one = True + + if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin: + c_cb.set_epoch_begin(self.ds_epoch_begin) + at_least_one = True + if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end: + c_cb.set_epoch_end(self.ds_epoch_end) + at_least_one = True + + if self.__class__.ds_step_begin != DSCallback.ds_step_begin: + c_cb.set_step_begin(self.ds_step_begin) + at_least_one = True + if self.__class__.ds_step_end != DSCallback.ds_step_end: + c_cb.set_step_end(self.ds_step_end) + at_least_one = True + + if not at_least_one: + raise AttributeError("Provided Callback class did not override any of the 6 callback methods.") + + return c_cb + + +class WaitedDSCallback(Callback, DSCallback): + """ + Abstract base class used to build a dataset callback class that is synchronized with the training callbacks. + + This class can be used to execute user-defined logic right after the previous step or epoch. + For example, an augmentation may need the loss from the previously trained epoch to update some of its parameters. + + Examples: + >>> my_cb = MyWaitedCallback(32) + >>> data = data.map(operations=AugOp(), callbacks=my_cb) + >>> data = data.batch(32) + >>> # define the model + >>> model.train(epochs, data, callbacks=[my_cb]) + + + Args: + step_size: the number of rows in each step.
+ Usually the step size will be equal to the batch size (Default=1) + """ + + def __init__(self, step_size=1): + super().__init__() + self.step_size = step_size + self.step_event = threading.Event() + self.step_run_context = None + + self.epoch_event = threading.Event() + self.epoch_run_context = None + + def sync_epoch_begin(self, train_run_context, ds_run_context): + """ + Called before a new dataset epoch is started and after the previous training epoch is ended. + + Args: + train_run_context: Include some information of the model with feedback from the previous epoch. + ds_run_context: Include some information of the dataset pipeline. + """ + + def sync_step_begin(self, train_run_context, ds_run_context): + """ + Called before a new dataset step is started and after the previous training step is ended. + + Args: + train_run_context: Include some information of the model with feedback from the previous step. + ds_run_context: Include some information of the dataset pipeline. + """ + + def epoch_end(self, run_context): + """ + Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin. + + Args: + run_context: Include some information of the model. + """ + self.epoch_run_context = run_context + self.epoch_event.set() + self.epoch_event.clear() + + def ds_epoch_begin(self, ds_run_context): + """ + Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback. + + Args: + ds_run_context: Include some information of the pipeline. + """ + if ds_run_context.cur_epoch_num > 1: + if self.epoch_run_context is None: + self.epoch_event.wait() + self.sync_epoch_begin(self.epoch_run_context, ds_run_context) + self.epoch_run_context = None + + def step_end(self, run_context): + """ + Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin. + + Args: + run_context: Include some information of the model. + """ + self.step_run_context = run_context + self.step_event.set() + self.step_event.clear() + + def ds_step_begin(self, ds_run_context): + """ + Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback. + + Args: + ds_run_context: Include some information of the pipeline. + """ + if ds_run_context.cur_step_num > self.step_size: + if self.step_run_context is None: + self.step_event.wait() + self.sync_step_begin(self.step_run_context, ds_run_context) + self.step_run_context = None + + def create_runtime_obj(self): + """ + Creates a runtime (C++) object from the callback methods defined by the user. This method is internal. 
+ + Returns: _c_dataengine.PyDSCallback + """ + c_cb = PyDSCallback(self.step_size) + at_least_one = False + + if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin: + c_cb.set_step_begin(self.ds_step_begin) + at_least_one = True + + if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin: + c_cb.set_epoch_begin(self.ds_epoch_begin) + at_least_one = True + + if not at_least_one: + raise AttributeError("Provided Callback class did not override any of the 2 callback methods.") + + return c_cb diff --git a/mindspore/dataset/callback/validators.py b/mindspore/dataset/callback/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8ebc1fe256634fc4b913a8fd8d740e585b1ab2 --- /dev/null +++ b/mindspore/dataset/callback/validators.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Built-in validators. +""" + +from functools import wraps + +from ..core.validator_helpers import parse_user_args, check_pos_int32 + + +def check_callback(method): + """check the input arguments of DSCallback.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [step_size], _ = parse_user_args(method, *args, **kwargs) + check_pos_int32(step_size, "step_size") + return method(self, *args, **kwargs) + + return new_method diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 3f1d3e9827c37eb590f439c82c0bd9adbb42ae68..3fa390b5e1efe0a0276b3404099f931dfbdc66ee 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\ + check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \ check_paddeddataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE @@ -395,7 +395,7 @@ class Dataset: @check_map def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None, python_multiprocessing=False, cache=None): + num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None): """ Apply each operation in operations to this dataset. @@ -438,6 +438,8 @@ option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). The cache feature is under development and is not recommended. + callbacks: (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None). + Returns: MapDataset, dataset after mapping operation. @@ -552,7 +554,7 @@ class Dataset: >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) """ return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, - python_multiprocessing, cache) + python_multiprocessing, cache, callbacks) @check_filter def filter(self, predicate, input_columns=None, num_parallel_workers=1): @@ -1548,6 +1550,7 @@ class DatasetOp(Dataset): return self.children[0].get_class_indexing() raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self))) + class BucketBatchByLengthDataset(DatasetOp): """ The result of applying BucketBatchByLength operator to the input dataset. @@ -1964,14 +1967,14 @@ class MapDataset(DatasetOp): option could be beneficial if the python operation is computational heavy (default=False). cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). The cache feature is under development and is not recommended. - + callbacks: (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None) Raises: ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. """ def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None, python_multiprocessing=False, cache=None): + num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None): super().__init__(num_parallel_workers) self.children.append(input_dataset) if input_columns is not None and not isinstance(input_columns, list): @@ -1996,6 +1999,11 @@ class MapDataset(DatasetOp): self.python_multiprocessing = python_multiprocessing self.process_pool = None + if callbacks is not None and not isinstance(callbacks, list): + callbacks = [callbacks] + + self.callbacks = callbacks + def get_args(self): args = super().get_args() args["input_columns"] = self.input_columns @@ -2003,6 +2011,9 @@ class MapDataset(DatasetOp): args["output_columns"] = self.output_columns args["columns_order"] = self.columns_order args["cache"] = self.cache.cache_client if self.cache is not None else None + + if self.callbacks is not None: + args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks] return args def get_dataset_size(self): @@ -2034,6 +2045,7 @@ class MapDataset(DatasetOp): new_op.cache = copy.deepcopy(self.cache, memodict) new_op.operations = self.operations new_op.dataset_size = self.dataset_size + new_op.callbacks = self.callbacks return new_op # Iterator bootstrap will be called on iterator construction. 
@@ -2393,7 +2405,6 @@ class ConcatDataset(DatasetOp): self._children_start_end_index_[index][0] = cumulative_samples_nums self._children_start_end_index_[index][1] = tem_value % sampler.num_shards - tem_sampler = copy.deepcopy(sampler) tem_sampler.set_offset(cumulative_samples_nums) child.sampler = tem_sampler @@ -2556,7 +2567,7 @@ class RangeDataset(MappableDataset): def get_dataset_size(self): if self.dataset_size is None: - self.dataset_size = math.ceil((self.stop - self.start)/self.step) + self.dataset_size = math.ceil((self.stop - self.start) / self.step) return self.dataset_size @@ -3423,7 +3434,7 @@ class GeneratorDataset(MappableDataset): if not self.num_shards: self.dataset_size = len(self.source) else: - self.dataset_size = math.ceil(len(self.source)/self.num_shards) + self.dataset_size = math.ceil(len(self.source) / self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: @@ -5428,6 +5439,7 @@ class NumpySlicesDataset(GeneratorDataset): num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_shards=num_shards, shard_id=shard_id) + class _PaddedDataset: """ Mainly for combining false samples provided by users into a dataset. @@ -5435,6 +5447,7 @@ class _PaddedDataset: Args: padded_samples (list(dict)): the data provided by user to added to initial Dataset """ + def __init__(self, padded_samples): self.column_names = list(padded_samples[0].keys()) self.padded_samples = padded_samples @@ -5445,6 +5458,7 @@ class _PaddedDataset: def __len__(self): return len(self.padded_samples) + class PaddedDataset(GeneratorDataset): """ Create a dataset with fake data provided by user. Mainly used to add to the original data set @@ -5463,6 +5477,7 @@ class PaddedDataset(GeneratorDataset): >>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}] >>> ds1 = ds.PaddedDataset(data1) """ + @check_paddeddataset def __init__(self, padded_samples): dataset = _PaddedDataset(padded_samples) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index dbe4ff6212bfe80885e9394c75489e407217489b..1897828e90953af35f3f1ad9ca60ec60e3cef88a 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -23,6 +23,7 @@ from functools import wraps import numpy as np from mindspore._c_expression import typing +from mindspore.dataset.callback import DSCallback from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ @@ -31,6 +32,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis from . import datasets from . import samplers from . import cache_client +from .. 
import callback def check_imagefolderdatasetv2(method): @@ -247,6 +249,7 @@ def check_celebadataset(method): return new_method + def check_save(method): """A wrapper that wrap a parameter checker to the save op.""" @@ -257,7 +260,7 @@ def check_save(method): nreq_param_int = ['num_files'] nreq_param_str = ['file_name', 'file_type'] validate_dataset_param_value(nreq_param_int, param_dict, int) - if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): + if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): raise ValueError("num_files should between {} and {}.".format(1, 1000)) validate_dataset_param_value(nreq_param_str, param_dict, str) if param_dict.get('file_type') != 'mindrecord': @@ -265,6 +268,8 @@ def check_save(method): return method(self, *args, **kwargs) return new_method + + def check_minddataset(method): """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" @@ -362,6 +367,7 @@ def check_generatordataset(method): return new_method + def check_random_dataset(method): """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset).""" @@ -545,7 +551,8 @@ def check_map(method): @wraps(method) def new_method(self, *args, **kwargs): - [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \ + [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache, + callbacks], _ = \ parse_user_args(method, *args, **kwargs) nreq_param_columns = ['input_columns', 'output_columns'] @@ -558,9 +565,17 @@ def check_map(method): if cache is not None: type_check(cache, (cache_client.DatasetCache,), "cache") + if callbacks is not None: + if isinstance(callbacks, (list, tuple)): + type_check_list(callbacks, (callback.DSCallback,), "callbacks") + else: + type_check(callbacks, (callback.DSCallback,), "callbacks") + for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): if param is not None: check_columns(param, param_name) + if callbacks is not None: + type_check(callbacks, (list, DSCallback), "callbacks") return method(self, *args, **kwargs) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 6b5f08af99b0881daa91722374fe232cc738b002..8681f830c6d568c4513bcda6ee3d00d98d5692a3 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -15,6 +15,7 @@ SET(DE_UT_SRCS bounding_box_augment_op_test.cc arena_test.cc btree_test.cc + callback_test.cc center_crop_op_test.cc channel_swap_test.cc circular_pool_test.cc diff --git a/tests/ut/cpp/dataset/callback_test.cc b/tests/ut/cpp/dataset/callback_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..39671b38047fd6c94682cc7ae9713edb34bc655e --- /dev/null +++ b/tests/ut/cpp/dataset/callback_test.cc @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <memory>
+#include <string>
+
+#include "common/common.h"
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
+#include "minddata/dataset/kernels/data/no_op.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+using mindspore::LogStream;
+using mindspore::MsLogLevel::INFO;
+
+namespace mindspore {
+namespace dataset {
+namespace test {
+
+std::shared_ptr<ExecutionTree> BuildTree(std::vector<std::shared_ptr<DatasetOp>> ops) {
+  std::shared_ptr<ExecutionTree> tree = std::make_shared<ExecutionTree>();
+  Status rc;
+  for (int i = 0; i < ops.size(); i++) {
+    rc = tree->AssociateNode(ops[i]);
+    EXPECT_TRUE(rc.IsOk());
+    if (i > 0) {
+      rc = ops[i]->AddChild(ops[i - 1]);
+      EXPECT_TRUE(rc.IsOk());
+    }
+    if (i == ops.size() - 1) {
+      rc = tree->AssignRoot(ops[i]);
+      EXPECT_TRUE(rc.IsOk());
+    }
+  }
+  return tree;
+}
+
+class TestCallback : public DSCallback {
+ public:
+  TestCallback(int32_t step_size)
+      : DSCallback(step_size),
+        begin_(true),
+        epoch_begin_(true),
+        step_begin_(true),
+        end_(true),
+        epoch_end_(true),
+        step_end_(true) {
+    all_names_.reserve(32);
+    all_step_nums_.reserve(32);
+    all_ep_nums_.reserve(32);
+  }
+
+  Status DSBegin(const CallbackParam &cb_param) override {
+    all_names_.push_back("BGN");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSEpochBegin(const CallbackParam &cb_param) override {
+    all_names_.push_back("EPBGN");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSNStepBegin(const CallbackParam &cb_param) override {
+    all_names_.push_back("SPBGN");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSEnd(const CallbackParam &cb_param) override {
+    all_names_.push_back("END");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSEpochEnd(const CallbackParam &cb_param) override {
+    all_names_.push_back("EPEND");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSNStepEnd(const CallbackParam &cb_param) override {
+    all_names_.push_back("SPEND");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+
+  bool IsBeginNeeded() override { return begin_; }
+  bool IsEpochBeginNeeded() override { return epoch_begin_; }
+  bool IsNStepBeginNeeded() override { return step_begin_; }
+  bool IsEndNeeded() override { return end_; }
+  bool IsEpochEndNeeded() override { return epoch_end_; }
+  bool IsNStepEndNeeded() override { return step_end_; }
+
+  std::vector<std::string> all_names(size_t len) {
+    return std::vector<std::string>(all_names_.begin(), all_names_.begin() + len);
+  }
+
+  std::vector<int64_t> all_step_nums(size_t len) {
+    return std::vector<int64_t>(all_step_nums_.begin(), all_step_nums_.begin() + len);
+  }
+
+  std::vector<int64_t> all_ep_nums(size_t len) {
+    return std::vector<int64_t>(all_ep_nums_.begin(), all_ep_nums_.begin() + len);
+  }
+
+  // flag for turning callback on and off
+  bool begin_, epoch_begin_, step_begin_, end_, epoch_end_, step_end_;
+  // name of the callback function in sequence, BGN, EPBGN, SPB, END, EPEND, SPEND
+  std::vector<std::string> all_names_;
+  std::vector<int64_t> all_step_nums_, all_ep_nums_;
+};
+
+}  // namespace test
+}  // namespace dataset
+}  // namespace mindspore
+
+class MindDataTestCallback : public UT::DatasetOpTesting {
+ public:
+  void SetUp() override {
+    DatasetOpTesting::SetUp();
+    GlobalInit();
+  }
+};
+
+TEST_F(MindDataTestCallback, TestBasicCallback) {
+  // config callback
+  Status rc;
+  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
+  std::shared_ptr<DSCallback> cb1 = tst_cb;
+  tst_cb->end_ = false;  // don't do the end for now due to a timing issue
+  // config leaf_op, use random_data to avoid I/O
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
+  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
+  schema->AddColumn(col);
+  std::shared_ptr<RandomDataOp> leaf;
+  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf);
+  EXPECT_TRUE(rc.IsOk());
+  // config mapOp
+  std::shared_ptr<MapOp> map_op;
+  auto map_b = MapOp::Builder();
+  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
+  EXPECT_TRUE(rc.IsOk());
+  // config RepeatOp
+  std::shared_ptr<RepeatOp> repeat_op;
+  rc = RepeatOp::Builder(2).Build(&repeat_op);
+  // start build then launch tree
+  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
+  rc = tree->Prepare();
+  EXPECT_TRUE(rc.IsOk());
+  rc = tree->Launch();
+  EXPECT_TRUE(rc.IsOk());
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorMap tensor_map;
+  rc = di.GetNextAsMap(&tensor_map);
+  EXPECT_TRUE(rc.IsOk());
+  while (!tensor_map.empty()) {
+    rc = di.GetNextAsMap(&tensor_map);
+    EXPECT_TRUE(rc.IsOk());
+  }
+
+  std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
+  std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
+  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
+  // doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
+  size_t len = 7;
+  EXPECT_EQ(tst_cb->all_names(len), callback_names);
+  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
+  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
+}
+
+TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
+  // config callback
+  Status rc;
+  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
+  std::shared_ptr<DSCallback> cb1 = tst_cb;
+  tst_cb->end_ = false;  // don't do the end for now due to a timing issue
+  // config leaf_op, use random_data to avoid I/O
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
+  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
+  schema->AddColumn(col);
+  std::shared_ptr<RandomDataOp> leaf;
+  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
+  EXPECT_TRUE(rc.IsOk());
+  // config mapOp
+  std::shared_ptr<MapOp> map_op;
+  auto map_b = MapOp::Builder();
+  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
+  EXPECT_TRUE(rc.IsOk());
+  // config RepeatOp
+  std::shared_ptr<RepeatOp> repeat_op;
+  rc = RepeatOp::Builder(2).Build(&repeat_op);
+  // start build then launch tree
+  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
+  rc = tree->Prepare();
+  EXPECT_TRUE(rc.IsOk());
+  rc = tree->Launch();
+  EXPECT_TRUE(rc.IsOk());
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorMap tensor_map;
+  size_t num_epochs = 2;
+  for (int ep_num = 0; ep_num < num_epochs; ++ep_num) {
+    di.GetNextAsMap(&tensor_map);
+    EXPECT_TRUE(rc.IsOk());
+
+    while (tensor_map.size() != 0) {
+      rc = di.GetNextAsMap(&tensor_map);
+      EXPECT_TRUE(rc.IsOk());
+    }
+  }
+
+  std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND",
+                                             "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
+
+  std::vector<int64_t> all_steps = {0, 0, 1, 1, 5, 5, 8, 8, 9, 9, 13, 13, 16};
+  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2};
+
+  size_t len = 13;
+  EXPECT_EQ(tst_cb->all_names(len), callback_names);
+  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
+  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
+}
+
+TEST_F(MindDataTestCallback, TestSelectedCallback) {
+  // config callback
+  Status rc;
+  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
+  std::shared_ptr<DSCallback> cb1 = tst_cb;
+  tst_cb->end_ = false;
+  // turn off the epochs
+  tst_cb->epoch_begin_ = false;
+  tst_cb->epoch_end_ = false;
+
+  // config leaf_op, use random_data to avoid I/O
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
+  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
+  schema->AddColumn(col);
+  std::shared_ptr<RandomDataOp> leaf;
+  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
+  EXPECT_TRUE(rc.IsOk());
+  // config mapOp
+  std::shared_ptr<MapOp> map_op;
+  auto map_b = MapOp::Builder();
+  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
+  EXPECT_TRUE(rc.IsOk());
+  // config RepeatOp
+  std::shared_ptr<RepeatOp> repeat_op;
+  rc = RepeatOp::Builder(2).Build(&repeat_op);
+  // start build then launch tree
+  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
+  rc = tree->Prepare();
+  EXPECT_TRUE(rc.IsOk());
+  rc = tree->Launch();
+  EXPECT_TRUE(rc.IsOk());
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorMap tensor_map;
+  size_t num_epochs = 2;
+  for (int ep_num = 0; ep_num < num_epochs; ++ep_num) {
+    di.GetNextAsMap(&tensor_map);
+    EXPECT_TRUE(rc.IsOk());
+
+    while (tensor_map.size() != 0) {
+      rc = di.GetNextAsMap(&tensor_map);
+      EXPECT_TRUE(rc.IsOk());
+    }
+  }
+
+  std::vector<std::string> callback_names = {"BGN", "SPBGN", "SPEND", "SPBGN", "SPEND",
+                                             "SPBGN", "SPEND", "SPBGN", "SPEND"};
+
+  std::vector<int64_t> all_steps = {0, 1, 1, 5, 5, 9, 9, 13, 13};
+  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 2, 2, 2, 2};
+
+  size_t len = 9;
+  EXPECT_EQ(tst_cb->all_names(len), callback_names);
+  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
+  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
+}
diff --git a/tests/ut/python/dataset/test_callbacks.py b/tests/ut/python/dataset/test_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c56dd1bcbb4a9cd924c47f3a1afc91129b34491
--- /dev/null
+++ b/tests/ut/python/dataset/test_callbacks.py
@@ -0,0 +1,365 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +from builtins import range, super +import time + +import pytest + +from mindspore import context +from mindspore import log as logger +from mindspore.dataset.callback import DSCallback, WaitedDSCallback +from mindspore.train import Model +from mindspore.train.callback import Callback + +import mindspore.dataset as ds +import mindspore.nn as nn + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class MyDSCallback(DSCallback): + def __init__(self, step_size=1, events=None, cb_id=0): + super().__init__(step_size) + self.events = events + self.cb_id = cb_id + + def append(self, event_name, ds_run_context): + event = [event_name, ds_run_context.cur_epoch_num, + ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num] + event = '_'.join([str(e) for e in event]) + index = -1 + for i, e in enumerate(self.events): + if e[0] == event: + index = i + break + if index != -1: + self.events[index][1].append(self.cb_id) + else: + self.events.append((event, [self.cb_id])) + + def ds_begin(self, ds_run_context): + self.append("begin", ds_run_context) + + def ds_end(self, ds_run_context): + self.append("end", ds_run_context) + + def ds_epoch_begin(self, ds_run_context): + self.append("epoch_begin", ds_run_context) + + def ds_epoch_end(self, ds_run_context): + self.append("epoch_end", ds_run_context) + + def ds_step_begin(self, ds_run_context): + self.append("step_begin", ds_run_context) + + def ds_step_end(self, ds_run_context): + self.append("step_end", ds_run_context) + + +def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1): + events = [] + cb_id = list(range(map_num)) + + def append(name, e, s): + event = [name, e + 1, s + 1, e * step_num * repeat + s + 1] + event = '_'.join([str(ev) for ev in event]) + events.append((event, cb_id)) + + events.append(("begin_0_0_0", cb_id)) + for e in range(epoch_num): + append("epoch_begin", e, -1) + for s in range(step_num * repeat): + if s % step_size == 0: + append("step_begin", e, s) + append("step_end", e, s) + append("epoch_end", e, step_num * repeat - 1) + return events + + +def build_test_case_1cb(epochs, steps, step_size=1, repeat=1): + events = [] + + arr = list(range(1, steps + 1)) + data = ds.NumpySlicesDataset(arr, shuffle=False) + + my_cb = MyDSCallback(step_size=step_size, events=events) + + data = data.map(operations=(lambda x: x), callbacks=my_cb) + if repeat != 1: + data = data.repeat(repeat) + itr = data.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + for _ in itr: + pass + + expected_events = generate_expected(epochs, steps, step_size, 1, repeat) + assert expected_events == events + + +def build_test_case_2cbs(epochs, steps): + events1 = [] + events2 = [] + my_cb1 = MyDSCallback(events=events1) + my_cb2 = MyDSCallback(events=events2) + + arr = list(range(1, steps + 1)) + data = ds.NumpySlicesDataset(arr, shuffle=False) + + data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2]) + + itr = data.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + for _ in itr: + pass + + expected_events = generate_expected(epochs, steps) + assert expected_events == events1 + assert expected_events == events2 + + +def build_test_case_2maps(epochs, steps): + events = [] + my_cb1 = MyDSCallback(events=events, cb_id=0) + my_cb2 = MyDSCallback(events=events, cb_id=1) + + arr = list(range(1, steps + 1)) + data = ds.NumpySlicesDataset(arr, shuffle=False) + + data = 
data.map(operations=(lambda x: x), callbacks=my_cb1) + data = data.map(operations=(lambda x: x), callbacks=my_cb2) + + itr = data.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + for _ in itr: + pass + + expected_events = generate_expected(epochs, steps, map_num=2) + + assert expected_events[1:] == events[1:] + + for event in events: + assert len(event) == 2 + event, cb_ids = event + if event != "begin_0_0_0": + assert cb_ids[0] == 0 + assert cb_ids[1] == 1 + + +def test_callbacks_all_methods(): + logger.info("test_callbacks_all_methods") + + build_test_case_1cb(1, 1) + build_test_case_1cb(1, 2) + build_test_case_1cb(1, 3) + build_test_case_1cb(1, 4) + + build_test_case_1cb(2, 1) + build_test_case_1cb(2, 2) + build_test_case_1cb(2, 3) + build_test_case_1cb(2, 4) + + build_test_case_1cb(3, 1) + build_test_case_1cb(3, 2) + build_test_case_1cb(3, 3) + build_test_case_1cb(3, 4) + + +def test_callbacks_var_step_size(): + logger.info("test_callbacks_var_step_size") + + build_test_case_1cb(1, 2, 2) + build_test_case_1cb(1, 3, 2) + build_test_case_1cb(1, 4, 2) + + build_test_case_1cb(2, 2, 2) + build_test_case_1cb(2, 3, 2) + build_test_case_1cb(2, 4, 2) + + build_test_case_1cb(3, 2, 2) + build_test_case_1cb(3, 3, 2) + build_test_case_1cb(3, 4, 2) + + +def test_callbacks_all_2cbs(): + logger.info("test_callbacks_all_2cbs") + + build_test_case_2cbs(4, 1) + build_test_case_2cbs(4, 2) + build_test_case_2cbs(4, 3) + build_test_case_2cbs(4, 4) + + +def test_callbacks_2maps(): + logger.info("test_callbacks_2maps") + + build_test_case_2maps(5, 10) + + build_test_case_2maps(6, 9) + + +class MyWaitedCallback(WaitedDSCallback): + def __init__(self, events, step_size=1): + super().__init__(step_size) + self.events = events + + def sync_epoch_begin(self, train_run_context, ds_run_context): + event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}" + self.events.append(event) + + def sync_step_begin(self, train_run_context, ds_run_context): + event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}" + self.events.append(event) + + +class MyMSCallback(Callback): + def __init__(self, events): + self.events = events + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}" + self.events.append(event) + + def step_end(self, run_context): + cb_params = run_context.original_args() + event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}" + self.events.append(event) + + +class Net(nn.Cell): + def construct(self, x, y): + return x + + +def test_train_non_sink(): + logger.info("test_train_non_sink") + + events = [] + my_cb1 = MyWaitedCallback(events, 1) + my_cb2 = MyMSCallback(events) + arr = [1, 2, 3, 4] + data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=my_cb1) + + net = Net() + model = Model(net) + + model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) + expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3', + 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4', + 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4', + 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6', + 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8', + 'ms_step_end_2_8', 'ms_epoch_end_2_8'] + + assert events == expected_synced_events + + +def 
test_train_batch_size2(): + logger.info("test_train_batch_size2") + + events = [] + my_cb1 = MyWaitedCallback(events, 2) + my_cb2 = MyMSCallback(events) + arr = [1, 2, 3, 4] + data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=my_cb1) + data = data.batch(2) + net = Net() + model = Model(net) + + model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) + + expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3', + 'ms_step_end_1_2', + 'ms_epoch_end_1_2', 'ds_epoch_begin_2_4', + 'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7', + 'ms_step_end_2_4', 'ms_epoch_end_2_4'] + + assert events == expected_synced_events + + +def test_callbacks_validations(): + logger.info("test_callbacks_validations") + + with pytest.raises(Exception) as err: + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + data.map(operations=(lambda x: x), callbacks=0) + assert "Argument callbacks with value 0 is not " in str(err.value) + + with pytest.raises(Exception) as err: + my_cb1 = MyDSCallback() + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + data.map(operations=(lambda x: x), callbacks=[my_cb1, 0]) + assert "Argument callbacks[1] with value 0 is not " in str(err.value) + + with pytest.raises(Exception) as err: + class BadCB(DSCallback): + pass + + my_cb = BadCB() + + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=my_cb) + for _ in data: + pass + assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value) + + +def test_callback_sink_simulation(): + logger.info("test_callback_sink_simulation") + + events = [] + epochs = 2 + my_cb = MyWaitedCallback(events, 1) + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=my_cb) + data = data.to_device() + data.send(num_epochs=epochs) + for e in range(epochs): + for s in range(4): + time.sleep(0.5) + events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}") + my_cb.step_end(run_context=0) + events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}") + my_cb.epoch_end(run_context=0) + expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3', + 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4', + 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4', + 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6', + 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8', + 'ms_step_end_2_8', 'ms_epoch_end_2_8'] + + assert events == expected_synced_events + + +def test_callbacks_repeat(): + logger.info("test_callbacks_repeat") + + build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2) + build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3) + build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3) + build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3) + + +if __name__ == '__main__': + test_callbacks_all_methods() + test_callbacks_all_2cbs() + test_callbacks_2maps() + test_callbacks_validations() + test_callbacks_var_step_size() + test_train_batch_size2() + test_callback_sink_simulation() + test_callbacks_repeat()
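
For reference, a minimal usage sketch of the new callbacks argument of Dataset.map (not part of the patch). It only uses pieces shown in this change set: DSCallback with a step_size, the ds_* hook methods, the cur_* fields of the run context, and NumpySlicesDataset as in the tests above. The class name and the printed messages are illustrative only.

    import mindspore.dataset as ds
    from mindspore.dataset.callback import DSCallback


    class PrintCallback(DSCallback):
        """Illustrative callback: logs pipeline progress; step hooks fire every step_size steps."""

        def ds_begin(self, ds_run_context):
            print("pipeline started")

        def ds_epoch_begin(self, ds_run_context):
            print("epoch", ds_run_context.cur_epoch_num, "begins")

        def ds_step_begin(self, ds_run_context):
            print("step", ds_run_context.cur_step_num_in_epoch,
                  "(global step", ds_run_context.cur_step_num, ") begins")

        def ds_epoch_end(self, ds_run_context):
            print("epoch", ds_run_context.cur_epoch_num, "ends")


    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=PrintCallback(step_size=2))

    for _ in data.create_tuple_iterator(num_epochs=1):
        pass

A callback must override at least one of the six ds_* hooks, otherwise create_runtime_obj raises the "did not override any of the 6 callback methods" error exercised in test_callbacks_validations. WaitedDSCallback is the variant to use when a dataset hook has to wait on the corresponding mindspore.train callback, as shown in test_train_non_sink above.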