// data_feed.cc (last committed by wangguibao)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <utility>

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"

// Command-line flag --is_text_feed (default: false). It is not read anywhere
// in this file; NOTE(review): confirm which component consumes it.
DEFINE_bool(is_text_feed, false, "is_text_feed");

namespace paddle {
namespace framework {
W
wangguibao 已提交
41
void TextClassDataFeed::Init() {
W
wangguibao 已提交
42
  // hard coding for a specific datafeed
W
wangguibao 已提交
43 44 45 46 47 48 49 50
  feed_vec_.resize(2);
  // feed_vec_[0].reset(new LoDTensor);
  // feed_vec_[1].reset(new LoDTensor);
  all_slot_ids_ = {0, 1};
  use_slot_ids_ = {0, 1};
  use_slot_alias_ = {"words", "label"};

  file_content_buffer_host_.reset(new char[200*1024*1024],
W
wangguibao 已提交
51
                                  [](char *p) {delete[] p;});
W
wangguibao 已提交
52 53 54 55
  file_content_buffer_ = file_content_buffer_host_.get();
  file_content_buffer_ptr_ = file_content_buffer_;

  batch_id_host_.reset(new int[10240*1024],
W
wangguibao 已提交
56
                      [](int *p) {delete[] p;});  // max word num in a batch
W
wangguibao 已提交
57 58 59
  batch_id_buffer_ = batch_id_host_.get();

  label_host_.reset(new int[10240],
W
wangguibao 已提交
60
                    [](int *p) {delete[] p;});    // max label in a batch
W
wangguibao 已提交
61
  label_ptr_ = label_host_.get();
W
wangguibao 已提交
62 63 64
}

// TODO: use a more elegant implementation for this function.
W
wangguibao 已提交
65
bool TextClassDataFeed::ReadBatch() {
W
wangguibao 已提交
66 67 68 69
  paddle::framework::Vector<size_t> offset;
  int tlen = 0;
  int llen = 0;
  int inst_idx = 0;
W
wangguibao 已提交
70
  offset.resize(batch_size_ + 1);
W
wangguibao 已提交
71
  offset[0] = 0;
W
wangguibao 已提交
72
  while (inst_idx < batch_size_) {
W
wangguibao 已提交
73
    int ptr_offset = 0;
W
wangguibao 已提交
74
    if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
W
wangguibao 已提交
75 76 77 78
      break;
    }

    memcpy(reinterpret_cast<char *>(&llen),
W
wangguibao 已提交
79
          file_content_buffer_ptr_ + ptr_offset,
W
wangguibao 已提交
80 81 82
          sizeof(int));
    ptr_offset += sizeof(int);

W
wangguibao 已提交
83 84
    memcpy(reinterpret_cast<char *>(batch_id_buffer_ + tlen),
          file_content_buffer_ptr_ + ptr_offset,
W
wangguibao 已提交
85 86 87 88 89 90
          llen * sizeof(int));
    tlen += llen;

    offset[inst_idx + 1] = offset[inst_idx] + llen;
    ptr_offset += sizeof(int) * llen;

W
wangguibao 已提交
91 92
    memcpy(reinterpret_cast<char *>(label_ptr_ + inst_idx),
          file_content_buffer_ptr_ + ptr_offset,
W
wangguibao 已提交
93 94 95
          sizeof(int));
    ptr_offset += sizeof(int);

W
wangguibao 已提交
96
    file_content_buffer_ptr_ += ptr_offset;
W
wangguibao 已提交
97 98 99
    inst_idx++;
  }

W
wangguibao 已提交
100
  if (inst_idx != batch_size_) {
W
wangguibao 已提交
101 102 103 104 105
    return false;
  }

  LoD input_lod{offset};
  paddle::framework::Vector<size_t> label_offset;
W
wangguibao 已提交
106 107
  label_offset.resize(batch_size_ + 1);
  for (int i = 0; i <= batch_size_; ++i) {
W
wangguibao 已提交
108 109 110 111
    label_offset[i] = i;
  }

  LoD label_lod{label_offset};
W
wangguibao 已提交
112
  int64_t* input_ptr = feed_vec_[0]->mutable_data<int64_t>(
W
wangguibao 已提交
113 114
      {static_cast<int64_t>(offset.back()), 1},
      platform::CPUPlace());
W
wangguibao 已提交
115
  int64_t* label_ptr = feed_vec_[1]->mutable_data<int64_t>({batch_size_, 1},
W
wangguibao 已提交
116 117
                                                          platform::CPUPlace());
  for (unsigned int i = 0; i < offset.back(); ++i) {
W
wangguibao 已提交
118
    input_ptr[i] = static_cast<int64_t>(batch_id_buffer_[i]);
W
wangguibao 已提交
119
  }
W
wangguibao 已提交
120 121
  for (int i = 0; i < batch_size_; ++i) {
    label_ptr[i] = static_cast<int64_t>(label_ptr_[i]);
W
wangguibao 已提交
122
  }
W
wangguibao 已提交
123 124
  feed_vec_[0]->set_lod(input_lod);
  feed_vec_[1]->set_lod(label_lod);
W
wangguibao 已提交
125 126 127
  return true;
}

W
wangguibao 已提交
128 129 130 131
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
  for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
    if (name == use_slot_alias_[i]) {
      feed_vec_[i] = feed->GetMutable<LoDTensor>();
W
wangguibao 已提交
132 133 134 135
    }
  }
}

W
wangguibao 已提交
136
bool TextClassDataFeed::SetFile(const char* filename) {
W
wangguibao 已提交
137
  // termnum termid termid ... termid label
W
wangguibao 已提交
138
  int filesize = ReadWholeFile(filename, file_content_buffer_);
W
wangguibao 已提交
139 140 141 142
  // todo , remove magic number
  if (filesize < 0 || filesize >= 1024 * 1024 * 1024) {
    return false;
  }
W
wangguibao 已提交
143 144
  file_content_buffer_ptr_ = file_content_buffer_;
  file_size_ = filesize;
W
wangguibao 已提交
145 146 147
  return true;
}

W
wangguibao 已提交
148
int TextClassDataFeed::ReadWholeFile(const std::string& filename,
W
wangguibao 已提交
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
                                       char* buffer) {
  std::ifstream ifs(filename.c_str(), std::ios::binary);
  if (ifs.fail()) {
    return -1;
  }

  ifs.seekg(0, std::ios::end);
  int file_size = ifs.tellg();
  ifs.seekg(0, std::ios::beg);
  ifs.read(buffer, file_size);
  return file_size;
}

}   // namespace framework
}   // namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */