recordio.cc 2.3 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/recordio.h"
16

Y
Yu Yang 已提交
17
#include <fstream>
18 19 20
#include <string>
#include <vector>

Y
Yu Yang 已提交
21 22
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/writer.h"
Y
Refine  
Yu Yang 已提交
23

Y
Yu Yang 已提交
24 25 26
namespace paddle {
namespace pybind {

27 28
namespace {

Y
Yu Yang 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41
class RecordIOWriter {
 public:
  RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
                 size_t max_num_record)
      : stream_(filename), writer_(&stream_, compressor, max_num_record) {}

  void AppendTensor(const framework::LoDTensor& tensor) {
    tensors_.push_back(tensor);
  }

  void CompleteAppendTensor() {
    auto& ctx =
        *platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
F
fengjiayi 已提交
42
    framework::WriteToRecordIO(&writer_, tensors_, ctx);
Y
Yu Yang 已提交
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
    tensors_.clear();
  }

  void Close() {
    PADDLE_ENFORCE(tensors_.empty());
    writer_.Flush();
    stream_.close();
  }

 private:
  std::vector<framework::LoDTensor> tensors_;
  std::ofstream stream_;
  recordio::Writer writer_;
};

58 59 60 61
}  // namespace

// Registers the RecordIOWriter class, with its nested Compressor enum, on the
// given pybind11 module.
//
// Python usage (illustrative):
//   w = RecordIOWriter(filename, RecordIOWriter.Compressor.Snappy, max_num)
//   w.append_tensor(t)         # repeat for each tensor of a record
//   w.complete_append_tensor() # commit the record
//   w.close()
void BindRecordIOWriter(py::module* m) {
  py::class_<RecordIOWriter> writer(*m, "RecordIOWriter", "");
  py::enum_<recordio::Compressor>(writer, "Compressor", "")
      .value("Snappy", recordio::Compressor::kSnappy)
      .value("NoCompress", recordio::Compressor::kNoCompress);

  // py::init<> is pybind11's supported way to bind a constructor. The
  // hand-written placement-new "__init__" lambda is deprecated since
  // pybind11 2.2. The argument names are optional keywords; positional calls
  // keep working.
  writer
      .def(py::init<const std::string&, recordio::Compressor, size_t>(),
           py::arg("filename"), py::arg("compressor"),
           py::arg("max_num_record"))
      .def("append_tensor", &RecordIOWriter::AppendTensor)
      .def("complete_append_tensor", &RecordIOWriter::CompleteAppendTensor)
      .def("close", &RecordIOWriter::Close);
}

}  // namespace pybind
}  // namespace paddle