mnist_op.h 9.7 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
L
liubuyu 已提交
16 17
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_
Z
zhunaipan 已提交
18 19 20 21 22 23 24 25

#include <memory>
#include <string>
#include <algorithm>
#include <map>
#include <vector>
#include <utility>

L
liubuyu 已提交
26 27 28 29 30 31 32 33 34 35
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
Z
zhunaipan 已提交
36 37 38 39 40 41 42

namespace mindspore {
namespace dataset {
// Forward declaration of the Queue class template.
// (minddata/dataset/util/queue.h is also included above.)
template <typename T>
class Queue;

43
// Pair of a shared Tensor and a uint32_t. MnistOp stores one such pair per entry in
// image_label_pairs_ and passes it to LoadTensorRow (presumably <image tensor, label> - confirm
// against the implementation).
using MnistLabelPair = std::pair<std::shared_ptr<Tensor>, uint32_t>;
Z
zhunaipan 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87

class MnistOp : public ParallelOp, public RandomAccessOp {
 public:
  class Builder {
   public:
    // Constructor for Builder class of MnistOp
    Builder();

    // Destructor.
    ~Builder() = default;

    // Setter method
    // @param int32_t rows_per_buffer
    // @return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method
    // @param int32_t op_connector_size
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method
    // @param int32_t num_workers
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method
    // @param std::shared_ptr<Sampler> sampler
    // @return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
      builder_sampler_ = std::move(sampler);
      return *this;
    }

    // Setter method
Z
Zirui Wu 已提交
88
    // @param const std::string &dir
Z
zhunaipan 已提交
89 90 91 92 93 94
    // @return
    Builder &SetDir(const std::string &dir) {
      builder_dir_ = dir;
      return *this;
    }

Z
Zirui Wu 已提交
95 96 97 98 99 100 101
    // Setter method
    // @param const std::string &usage
    // @return
    Builder &SetUsage(const std::string &usage) {
      builder_usage_ = usage;
      return *this;
    }
Z
zhunaipan 已提交
102 103 104 105 106 107 108 109 110 111 112
    // Check validity of input args
    // @return - The error code return
    Status SanityCheck();

    // The builder "Build" method creates the final object.
    // @param std::shared_ptr<MnistOp> *op - DatasetOp
    // @return - The error code return
    Status Build(std::shared_ptr<MnistOp> *op);

   private:
    std::string builder_dir_;
Z
Zirui Wu 已提交
113
    std::string builder_usage_;
Z
zhunaipan 已提交
114 115 116 117 118 119 120 121
    int32_t builder_num_workers_;
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
    std::shared_ptr<Sampler> builder_sampler_;
    std::unique_ptr<DataSchema> builder_schema_;
  };

  // Constructor
Z
Zirui Wu 已提交
122
  // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
Z
zhunaipan 已提交
123 124 125 126 127 128
  // @param int32_t num_workers - number of workers reading images in parallel
  // @param int32_t rows_per_buffer - number of images (rows) in each buffer
  // @param std::string folder_path - dir directory of mnist
  // @param int32_t queue_size - connector queue size
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
  // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
Z
Zirui Wu 已提交
129 130
  MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
          int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
Z
zhunaipan 已提交
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146

  // Destructor.
  ~MnistOp() = default;

  // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
  // @param int32_t worker_id - id of each worker
  // @return Status - The error code return
  Status WorkerEntry(int32_t worker_id) override;

  // Main Loop of MnistOp
  // Master thread: Fill IOBlockQueue, then goes to sleep
  // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
  // @return Status - The error code return
  Status operator()() override;

  // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
147
  // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
Z
zhunaipan 已提交
148 149 150 151 152 153 154 155 156
  // @return Status - The error code return
  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;

  // A print method typically used for debugging
  // @param out
  // @param show_all
  void Print(std::ostream &out, bool show_all) const override;

  // Function to count the number of samples in the MNIST dataset
J
Jamie Nisbet 已提交
157
  // @param dir path to the MNIST directory
Z
zhunaipan 已提交
158 159
  // @param count output arg that will hold the minimum of the actual dataset size and numSamples
  // @return
Z
Zirui Wu 已提交
160
  static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
Z
zhunaipan 已提交
161

J
Jesse Lee 已提交
162 163 164 165 166 167
  /// \brief Base-class override for NodePass visitor acceptor
  /// \param[in] p Pointer to the NodePass to be accepted
  /// \param[out] modified Indicator if the node was changed at all
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;

168 169 170 171
  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return "MnistOp"; }

Z
zhunaipan 已提交
172 173 174 175 176 177
 private:
  // Initialize Sampler, calls sampler->Init() within
  // @return Status - The error code return
  Status InitSampler();

  // Load a tensor row according to a pair
178
  // @param row_id_type row_id - id for this tensor row
Z
zhunaipan 已提交
179 180 181
  // @param ImageLabelPair pair - <imagefile,label>
  // @param TensorRow row - image & label read into this tensor row
  // @return Status - The error code return
182
  Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row);
Z
zhunaipan 已提交
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241

  // @param const std::vector<int64_t> &keys - keys in ioblock
  // @param std::unique_ptr<DataBuffer> db
  // @return Status - The error code return
  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);

  // Iterate through all members in sampleIds and fill them into IOBlock.
  // @param std::shared_ptr<Tensor> sample_ids -
  // @param std::vector<int64_t> *keys - keys in ioblock
  // @return Status - The error code return
  Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);

  // Check image file stream.
  // @param const std::string *file_name - image file name
  // @param std::ifstream *image_reader - image file stream
  // @param uint32_t num_images - returns the number of images
  // @return Status - The error code return
  Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);

  // Check label stream.
  // @param const std::string &file_name - label file name
  // @param std::ifstream *label_reader - label file stream
  // @param uint32_t num_labels - returns the number of labels
  // @return Status - The error code return
  Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);

  // Read 4 bytes of data from a file stream.
  // @param std::ifstream *reader - file stream to read
  // @return uint32_t - read out data
  Status ReadFromReader(std::ifstream *reader, uint32_t *result);

  // Swap endian
  // @param uint32_t val -
  // @return uint32_t - swap endian data
  uint32_t SwapEndian(uint32_t val) const;

  // Read the specified number of images and labels from the file stream
  // @param std::ifstream *image_reader - image file stream
  // @param std::ifstream *label_reader - label file stream
  // @param int64_t read_num - number of image to read
  // @return Status - The error code return
  Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);

  // Parse all mnist dataset files
  // @return Status - The error code return
  Status ParseMnistData();

  // Read all files in the directory
  // @return Status - The error code return
  Status WalkAllFiles();

  // Called first when function is called
  // @return Status - The error code return
  Status LaunchThreadsAndInitOp();

  // reset Op
  // @return Status - The error code return
  Status Reset() override;

242 243 244 245
  // Private function for computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;

Z
zhunaipan 已提交
246 247 248 249 250
  int64_t buf_cnt_;
  int64_t row_cnt_;
  WaitPost wp_;
  std::string folder_path_;  // directory of image folder
  int32_t rows_per_buffer_;
Z
Zirui Wu 已提交
251
  const std::string usage_;  // can only be either "train" or "test"
Z
zhunaipan 已提交
252 253 254 255 256 257 258 259
  std::unique_ptr<DataSchema> data_schema_;
  std::vector<MnistLabelPair> image_label_pairs_;
  std::vector<std::string> image_names_;
  std::vector<std::string> label_names_;
  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
};
}  // namespace dataset
}  // namespace mindspore
L
liubuyu 已提交
260
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_