提交 ad91e6e5 编写于 作者: H hesham

- Fix bug in counting epochs when DeviceQueue is used

上级 92d93ebc
......@@ -20,6 +20,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
namespace mindspore {
......@@ -258,6 +259,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
return Status::OK();
}
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
// Set total repeats and total epochs for the DeviceQueueOp
node->set_total_repeats(num_epochs_);
node->set_num_repeats_per_epoch(1);
return Status::OK();
}
// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
op_stack *current_stack = eoe_op_stacks_.top().get();
......
......@@ -92,6 +92,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
/// \brief Set the epoch count for DeviceQueue
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it.
/// \param[in] node The node being visited
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册