serialization.h 1.2 KB
Newer Older
D
dzhwinter 已提交
1
#pragma once
2 3 4

#include <limits>
#include <sstream>
#include <string>
D
dzhwinter 已提交
5
#include <type_traits>
6 7 8 9 10 11 12
#include "OptimizerConfig.pb.h"
#include "paddle/utils/Logging.h"
#include "tensor.h"

namespace paddle {
namespace optimizer {

D
dzhwinter 已提交
13 14 15 16 17 18 19 20 21 22 23
// Recursion terminator for CalStateSize: an empty argument list contributes 0.
//
// The `int*` parameter is retained (now defaulted) for source compatibility
// with the original single-argument overload; its value is ignored. Because
// it is defaulted, this overload is also what `CalStateSize(tail...)` resolves
// to once `tail` is empty -- the original had no zero-argument overload, so
// the recursion could never terminate at compile time.
inline unsigned CalStateSize(int* /*state_len*/ = nullptr) { return 0; }

// Size contribution of one fundamental value (int, float, ...): its sizeof.
template <typename T>
unsigned CalStateItemSize(const T& item, std::true_type /*is_fundamental*/) {
  return static_cast<unsigned>(sizeof item);
}

// Size contribution of a non-null pointer to a container-like object that
// provides operator[] and size() (e.g. `const Tensor*`): element size times
// element count. The `(*item)[0]` operand of sizeof is not evaluated, so an
// empty container is never indexed here.
template <typename T>
unsigned CalStateItemSize(const T& item, std::false_type /*is_fundamental*/) {
  return static_cast<unsigned>(sizeof((*item)[0]) * item->size());
}

// Sums the state size of every argument: sizeof for fundamental values and
// element-size * element-count for pointers to containers.
// NOTE(review): presumably used to size a serialized optimizer-state buffer;
// confirm against callers.
template <typename HEAD, typename... TAIL>
unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
  // Tag dispatch instead of a runtime `if`: only the overload valid for HEAD
  // is instantiated. (A runtime `if` compiles both branches, which made every
  // instantiation ill-formed.)
  return CalStateItemSize(head, typename std::is_fundamental<HEAD>::type{}) +
         CalStateSize(tail...);
}

24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
  proto->set_size(tensor.size());
  std::stringstream os;
  for (size_t i = 0; i < tensor.size(); ++i) {
    os << tensor[i];
    proto->add_content(os.str());
    os.clear();
  }
}

// Restores `tensor` from a proto written by TensorToProto: entry i of
// `content` is parsed as element i.
//
// Fails a CHECK (paddle/utils/Logging.h) in two cases:
//   - the recorded `size` differs from tensor->size();
//   - the number of `content` entries differs from tensor->size().
// The second guard stops a malformed proto from writing past the end of the
// tensor.
inline void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
  CHECK(proto.size() == tensor->size()) << "unmatch shape of proto and tensor";
  CHECK(static_cast<size_t>(proto.content_size()) == tensor->size())
      << "number of proto content entries does not match tensor size";
  std::istringstream sin;
  for (int i = 0; i < proto.content_size(); ++i) {
    // Parse every entry from a fresh buffer. The original shared stream grew
    // without bound, and any characters left unparsed by one entry (e.g.
    // after a failed parse) were prepended to the next entry.
    sin.str(proto.content(i));
    sin.clear();
    sin >> (*tensor)[i];
  }
}

}  // namespace optimizer
}  // namespace paddle