serialization.h 1.1 KB
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4

#include <limits>
#include <sstream>
#include <string>
D
dzhwinter 已提交
5
#include <type_traits>
6 7 8 9 10 11 12
#include "OptimizerConfig.pb.h"
#include "paddle/utils/Logging.h"
#include "tensor.h"

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
13
/// Base case of the CalStateSize recursion: an empty argument list
/// contributes no size.
/// Declared `inline` rather than `static`: a `static` function defined in a
/// header is duplicated into, and reported unused by, every including TU.
inline unsigned CalStateSize() { return 0; }
D
dzhwinter 已提交
14 15 16

/// Recursive case for an arbitrary argument: adds sizeof(head) to the size
/// computed for the remaining arguments. An empty remainder is handled by the
/// zero-argument CalStateSize() overload.
template <typename Head, typename... Rest>
unsigned CalStateSize(const Head& head, const Rest&... tail) {
  const unsigned remaining = CalStateSize(tail...);
  return static_cast<unsigned>(sizeof(head)) + remaining;
}

template <typename... TAIL>
unsigned CalStateSize(const Tensor* head, const TAIL&... tail) {
  return head->size() + CalStateSize(tail...);
D
dzhwinter 已提交
23 24
}

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
  std::stringstream os;
  for (size_t i = 0; i < tensor.size(); ++i) {
    os << tensor[i];
    proto->add_content(os.str());
    os.clear();
  }
}

static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
  std::stringstream sin;
  for (auto i = 0; i < proto.content_size(); ++i) {
    sin << proto.content(i);
    sin >> (*tensor)[i];
    sin.clear();
  }
}

}  // namespace optimizer
}  // namespace paddle