From 0945dc1b9968f92a23bcedbb24bf68aacd194f26 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 17 Aug 2017 10:31:46 +0800 Subject: [PATCH] enable header format --- paddle/parameter/Parameter.cpp | 10 ++++++---- paddle/parameter/Parameter.h | 29 +++++++++++++++++++++++++++-- paddle/pserver/ParameterServer2.cpp | 7 ++++--- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index ebe36d493..f03110950 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -48,7 +48,8 @@ Parameter::Parameter(const ParameterConfig& config, bool useGpu, bool doInit) deviceId_(-1), sharedCount_(0), updateCounter_(0), - updated_(false) { + updated_(false), + headerFormat_(PARAM_FORMAT_ORIGINAL) { setID(-1); /* capture uninitialized id */ if (useGpu_ && FLAGS_parallel_nn) { /* gpu environment is specified by device property */ @@ -285,7 +286,7 @@ bool Parameter::save(const std::string& filename) const { bool Parameter::save(std::ostream& s) const { CpuVector vec(*bufs_[PARAMETER_VALUE].get()); Header header; - header.version = kFormatVersion; + header.format = headerFormat_; header.valueSize = sizeof(real); header.size = getSize(); @@ -344,8 +345,9 @@ bool Parameter::load(std::istream& s) { Header header; CHECK(s.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameter " << getName(); - CHECK_EQ(header.version, kFormatVersion) << "Incorrect format version: " - << header.version; + CHECK(isHeaderFormatSupported(header.format)) << "Incorrect format version: " + << header.format; + headerFormat_ = header.format; CHECK_EQ(header.size, getSize()) << "The size (" << header.size << ") in the file does not match the size " << "(" << getSize() << ") of the parameter: " << getName(); diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 0bac76f06..cffd3aa92 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -34,6 +34,12 @@ limitations under the License. */ namespace paddle { +typedef enum { + PARAM_FORMAT_ORIGINAL = 0, // the paddle original basic format + PARAM_FORMAT_MKLDNN_OI, // the mkldnn format oi + PARAM_FORMAT_ITEMS, // the total format items numbers +} PARAM_FORMAT; + class SparsePrefetchRowCpuMatrix; class Parameter; @@ -242,14 +248,30 @@ public: /// Initialize the value to 0 void zeroMem(); - static const int kFormatVersion = 0; /// file header structure struct Header { - int32_t version; // = 0, file format version + int32_t format; // = PARAM_FORMAT uint32_t valueSize; // = sizeof(real) uint64_t size; // = getSize() }; + /** + * @brief Is the header supported + */ + static bool isHeaderFormatSupported(int32_t fmt) { + return fmt < PARAM_FORMAT_ITEMS; + } + + /** + * @brief Get the format in header + */ + int getHeaderFormat() { return headerFormat_; } + + /** + * @brief Set the format in header + */ + void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } + /** * @brief Parameter Update Hook. * @@ -321,6 +343,9 @@ protected: bool updated_; SparseFormat format_; + // The header format for saving or loading param + int32_t headerFormat_; + std::vector> updaterHooks_; public: diff --git a/paddle/pserver/ParameterServer2.cpp b/paddle/pserver/ParameterServer2.cpp index d7c1d4f78..54f5c4c0f 100644 --- a/paddle/pserver/ParameterServer2.cpp +++ b/paddle/pserver/ParameterServer2.cpp @@ -1032,8 +1032,8 @@ void ParameterServer2::loadValueVector(const LoadValueRequest& request, Parameter::Header header; CHECK(fs.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameters in pserver"; - CHECK_EQ(header.version, Parameter::kFormatVersion) - << "Incorrect format version: " << header.version; + CHECK(Parameter::isHeaderFormatSupported(header.format)) + << "Incorrect format version: " << header.format; CHECK_EQ(header.size, (size_t)size_) << "The size (" << header.size << ") in the file does not match the size " << "(" << size_ << ") of the pserver: " << serverId_; @@ -1063,7 +1063,8 @@ void ParameterServer2::saveValueVector(const SaveValueRequest& request, CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY] : *vectors_[PARAMETER_VALUE]; Parameter::Header header; - header.version = Parameter::kFormatVersion; + // TODO(TJ): save param headerFormat_ + header.format = PARAM_FORMAT_ORIGINAL; header.valueSize = sizeof(real); header.size = size_; -- GitLab