From 21d5ce57175e35e9df514e81f9c584f0bb2910ad Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Thu, 3 May 2018 05:16:26 +0000 Subject: [PATCH] Fix issue with big endian handling Signed-off-by: Stefan Weil --- src/lstm/input.cpp | 6 ++---- src/lstm/static_shape.h | 30 ++++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/lstm/input.cpp b/src/lstm/input.cpp index 6411800a..fe47520a 100644 --- a/src/lstm/input.cpp +++ b/src/lstm/input.cpp @@ -42,14 +42,12 @@ Input::~Input() { // Writes to the given file. Returns false in case of error. bool Input::Serialize(TFile* fp) const { - if (!Network::Serialize(fp)) return false; - if (fp->FWrite(&shape_, sizeof(shape_), 1) != 1) return false; - return true; + return Network::Serialize(fp) && shape_.Serialize(fp); } // Reads from the given file. Returns false in case of error. bool Input::DeSerialize(TFile* fp) { - return fp->FReadEndian(&shape_, sizeof(shape_), 1) == 1; + return shape_.DeSerialize(fp); } // Returns an integer reduction factor that the network applies to the diff --git a/src/lstm/static_shape.h b/src/lstm/static_shape.h index 4822a5af..9ca35255 100644 --- a/src/lstm/static_shape.h +++ b/src/lstm/static_shape.h @@ -59,18 +59,40 @@ class StaticShape { height_, width_, depth_, loss_type_); } + bool DeSerialize(TFile *fp) { + int32_t tmp; + bool result = + fp->FReadEndian(&batch_, sizeof(batch_), 1) == 1 && + fp->FReadEndian(&height_, sizeof(height_), 1) == 1 && + fp->FReadEndian(&width_, sizeof(width_), 1) == 1 && + fp->FReadEndian(&depth_, sizeof(depth_), 1) == 1 && + fp->FReadEndian(&tmp, sizeof(tmp), 1) == 1; + loss_type_ = static_cast(tmp); + return result; + } + + bool Serialize(TFile *fp) const { + int32_t tmp = loss_type_; + return + fp->FWrite(&batch_, sizeof(batch_), 1) == 1 && + fp->FWrite(&height_, sizeof(height_), 1) == 1 && + fp->FWrite(&width_, sizeof(width_), 1) == 1 && + fp->FWrite(&depth_, sizeof(depth_), 1) == 1 && + fp->FWrite(&tmp, sizeof(tmp), 1) == 1; + } + private: // Size of the 4-D tensor input/output to a network. A value of zero is // allowed for all except depth_ and means to be determined at runtime, and // regarded as variable. // Number of elements in a batch, or number of frames in a video stream. - int batch_; + int32_t batch_; // Height of the image. - int height_; + int32_t height_; // Width of the image. - int width_; + int32_t width_; // Depth of the image. (Number of "nodes"). - int depth_; + int32_t depth_; // How to train/interpret the output. LossType loss_type_; }; -- GitLab