chunk.cc 4.8 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/recordio/chunk.h"

17
#include <zlib.h>
Y
Yi Wang 已提交
18
#include <algorithm>
Y
Yu Yang 已提交
19
#include <memory>
D
dongzhihong 已提交
20
#include <sstream>
Y
Yi Wang 已提交
21

Y
Yu Yang 已提交
22
#include "paddle/fluid/platform/enforce.h"
23
#include "snappystream.hpp"
D
dongzhihong 已提交
24 25 26

namespace paddle {
namespace recordio {
Y
Yu Yang 已提交
27
// Size in bytes of the stack buffer used when draining a stream.
constexpr size_t kMaxBufSize = 1024;

/**
 * Read a stream through a fixed-size buffer, handing each filled portion to
 * `callback`.
 *
 * Reading stops when the stream reports end-of-file, when a read yields no
 * bytes, or when `limit` bytes have been consumed. On return every state flag
 * of `in` has been cleared, so the caller may seek or continue reading.
 *
 * @param in input stream
 * @param limit read at most `limit` bytes from input stream. 0 means no limit
 * @param callback A function object with (const char* buf, size_t size) -> void
 * as its type. It is invoked once per non-empty read, in stream order.
 */
template <typename Callback>
static void ReadStreamByBuf(std::istream& in, size_t limit, Callback callback) {
  char buf[kMaxBufSize];
  const bool limited = (limit != 0);
  size_t counter = 0;  // bytes delivered to callback so far

  // Keep going only while there is input left AND the limit (if any) has not
  // been reached. The previous condition was inverted and relied on a
  // zero-byte read to terminate once the limit was hit.
  while (!in.eof() && !(limited && counter >= limit)) {
    const size_t want =
        limited ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize;
    in.read(buf, static_cast<std::streamsize>(want));
    const std::streamsize got = in.gcount();
    if (got <= 0) {
      // Nothing could be extracted (exhausted or failed stream).
      break;
    }
    callback(buf, static_cast<size_t>(got));
    counter += static_cast<size_t>(got);
  }
  in.clear();  // unset eof state
}

Y
Yu Yang 已提交
59 60 61
/**
 * Copy stream in to another stream
 */
Y
Yu Yang 已提交
62
static void PipeStream(std::istream& in, std::ostream& os) {
Y
Yi Wang 已提交
63 64
  ReadStreamByBuf(in, 0,
                  [&os](const char* buf, size_t len) { os.write(buf, len); });
Y
Yu Yang 已提交
65
}
Y
Yu Yang 已提交
66 67 68 69

/**
 * Calculate CRC32 from an input stream.
 */
Y
Yu Yang 已提交
70 71
static uint32_t Crc32Stream(std::istream& in, size_t limit = 0) {
  uint32_t crc = static_cast<uint32_t>(crc32(0, nullptr, 0));
Y
Yu Yang 已提交
72
  ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) {
Y
Yi Wang 已提交
73 74
    crc = static_cast<uint32_t>(crc32(crc, reinterpret_cast<const Bytef*>(buf),
                                      static_cast<uInt>(len)));
Y
Yu Yang 已提交
75 76 77 78 79
  });
  return crc;
}

bool Chunk::Write(std::ostream& os, Compressor ct) const {
D
dongzhihong 已提交
80 81
  // NOTE(dzhwinter): don't check records.numBytes instead, because
  // empty records are allowed.
Y
Yu Yang 已提交
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
  if (records_.empty()) {
    return false;
  }
  std::stringstream sout;
  std::unique_ptr<std::ostream> compressed_stream;
  switch (ct) {
    case Compressor::kNoCompress:
      break;
    case Compressor::kSnappy:
      compressed_stream.reset(new snappy::oSnappyStream(sout));
      break;
    default:
      PADDLE_THROW("Not implemented");
  }

  std::ostream& buf_stream = compressed_stream ? *compressed_stream : sout;
D
dongzhihong 已提交
98 99

  for (auto& record : records_) {
Y
Yu Yang 已提交
100 101 102
    size_t sz = record.size();
    buf_stream.write(reinterpret_cast<const char*>(&sz), sizeof(uint32_t))
        .write(record.data(), record.size());
D
dongzhihong 已提交
103 104
  }

Y
Yu Yang 已提交
105 106 107 108
  if (compressed_stream) {
    compressed_stream.reset();
  }

Y
Yu Yang 已提交
109 110 111
  sout.seekg(0, std::ios::end);
  uint32_t len = static_cast<uint32_t>(sout.tellg());
  sout.seekg(0, std::ios::beg);
Y
Yu Yang 已提交
112 113 114
  uint32_t crc = Crc32Stream(sout);
  Header hdr(static_cast<uint32_t>(records_.size()), crc, ct, len);
  hdr.Write(os);
Y
Yu Yang 已提交
115 116
  sout.seekg(0, std::ios::beg);
  sout.clear();
Y
Yu Yang 已提交
117
  PipeStream(sout, os);
D
dongzhihong 已提交
118 119 120
  return true;
}

Y
Yu Yang 已提交
121
bool Chunk::Parse(std::istream& sin) {
Y
yuyang18 已提交
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
  ChunkParser parser(sin);
  if (!parser.Init()) {
    return false;
  }
  Clear();
  while (parser.HasNext()) {
    Add(parser.Next());
  }
  return true;
}

// Stores `sin` in `in_`; the constructor itself reads nothing from the stream.
ChunkParser::ChunkParser(std::istream& sin)
    : in_(sin) {
}
bool ChunkParser::Init() {
  pos_ = 0;
  bool ok = header_.Parse(in_);
Y
Yu Yang 已提交
137 138 139
  if (!ok) {
    return ok;
  }
Y
yuyang18 已提交
140 141 142 143 144 145
  auto beg_pos = in_.tellg();
  uint32_t crc = Crc32Stream(in_, header_.CompressSize());
  PADDLE_ENFORCE_EQ(header_.Checksum(), crc);
  in_.seekg(beg_pos, in_.beg);

  switch (header_.CompressType()) {
D
dongzhihong 已提交
146 147 148
    case Compressor::kNoCompress:
      break;
    case Compressor::kSnappy:
Y
yuyang18 已提交
149
      compressed_stream_.reset(new snappy::iSnappyStream(in_));
D
dongzhihong 已提交
150
      break;
Y
Yu Yang 已提交
151 152
    default:
      PADDLE_THROW("Not implemented");
D
dongzhihong 已提交
153
  }
Y
yuyang18 已提交
154 155
  return true;
}
D
dongzhihong 已提交
156

Y
yuyang18 已提交
157
// True while fewer records have been consumed than the header reports.
bool ChunkParser::HasNext() const {
  const bool records_remaining = pos_ < header_.NumRecords();
  return records_remaining;
}
Y
Yu Yang 已提交
158

Y
yuyang18 已提交
159 160 161
std::string ChunkParser::Next() {
  if (!HasNext()) {
    return "";
D
dongzhihong 已提交
162
  }
Y
yuyang18 已提交
163 164 165 166 167 168 169 170 171
  ++pos_;
  std::istream& stream = compressed_stream_ ? *compressed_stream_ : in_;
  uint32_t rec_len;
  stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
  std::string buf;
  buf.resize(rec_len);
  stream.read(&buf[0], rec_len);
  PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
  return buf;
D
dongzhihong 已提交
172 173 174
}
}  // namespace recordio
}  // namespace paddle