selected_rows.h 2.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5 6 7 8 9 10 11 12
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
13 14 15

#include <cstdint>
#include <iosfwd>
#include <memory>
#include <vector>

Y
Yi Wang 已提交
16 17
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
Q
qijun 已提交
18 19 20 21 22 23 24 25 26 27 28

namespace paddle {
namespace framework {

class SelectedRows {
 public:
  SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
      : rows_(rows), height_(height) {
    value_.reset(new Tensor());
  }

Q
QI JUN 已提交
29 30 31 32
  SelectedRows() {
    height_ = 0;
    value_.reset(new Tensor());
  }
Q
qijun 已提交
33 34 35

  platform::Place place() const { return value_->place(); }

Q
qijun 已提交
36 37 38
  const Tensor& value() const { return *value_; }

  Tensor* mutable_value() { return value_.get(); }
Q
qijun 已提交
39 40 41 42 43

  int64_t height() const { return height_; }

  void set_height(int64_t height) { height_ = height; }

Q
qijun 已提交
44
  const Vector<int64_t>& rows() const { return rows_; }
Q
qijun 已提交
45

Q
QI JUN 已提交
46 47
  Vector<int64_t>* mutable_rows() { return &rows_; }

Q
qijun 已提交
48
  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
Q
qijun 已提交
49 50 51 52 53 54 55 56

  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
    return make_ddim(dims);
  }

 private:
Q
qijun 已提交
57
  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
58
  // SelectedRows are simply concated when adding together. Until a
Q
qijun 已提交
59
  // SelectedRows add a Tensor, will the duplicate rows be handled.
Q
qijun 已提交
60
  Vector<int64_t> rows_;
Q
qijun 已提交
61 62 63 64
  std::unique_ptr<Tensor> value_{nullptr};
  int64_t height_;
};

65 66 67 68 69 70 71
/*
 * Serialize/Desiralize SelectedRows to std::ostream
 * You can pass ofstream or ostringstream to serilize to file
 * or to a in memory string. GPU tensor will be copied to CPU.
 */
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
                       const platform::DeviceContext& dev_ctx);
Y
Yancey 已提交
72 73
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
                           const platform::DeviceContext& dev_ctx);
74

Q
qijun 已提交
75 76
}  // namespace framework
}  // namespace paddle