未验证 提交 a38db3cb 编写于 作者: W wopeizl 提交者: GitHub

Fixrecordio (#16124)

* fix recordio on win
test=develop

* test=develop

* test=develop

* fix code style
test=develop

* test=develop
上级 3c60446e
...@@ -31,7 +31,7 @@ class RecordIOWriter { ...@@ -31,7 +31,7 @@ class RecordIOWriter {
RecordIOWriter(const std::string& filename, recordio::Compressor compressor, RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
size_t max_num_record) size_t max_num_record)
: closed_(false), : closed_(false),
stream_(filename), stream_(filename, std::ios::binary),
writer_(&stream_, compressor, max_num_record) {} writer_(&stream_, compressor, max_num_record) {}
void AppendTensor(const framework::LoDTensor& tensor) { void AppendTensor(const framework::LoDTensor& tensor) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/recordio/scanner.h" #include "paddle/fluid/recordio/scanner.h"
#include <string> #include <string>
#include <utility>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -27,7 +28,8 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream) ...@@ -27,7 +28,8 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
} }
Scanner::Scanner(const std::string &filename) Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) { : stream_(new std::ifstream(filename, std::ios::in | std::ios::binary)),
parser_(*stream_) {
PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename); PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename);
Reset(); Reset();
} }
......
...@@ -26,8 +26,8 @@ class TestAccuracyOp(OpTest): ...@@ -26,8 +26,8 @@ class TestAccuracyOp(OpTest):
self.init_dtype() self.init_dtype()
n = 8192 n = 8192
infer = np.random.random((n, 1)).astype(self.dtype) infer = np.random.random((n, 1)).astype(self.dtype)
indices = np.random.randint(0, 2, (n, 1)) indices = np.random.randint(0, 2, (n, 1)).astype('int64')
label = np.random.randint(0, 2, (n, 1)) label = np.random.randint(0, 2, (n, 1)).astype('int64')
self.inputs = {'Out': infer, 'Indices': indices, "Label": label} self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
num_correct = 0 num_correct = 0
for rowid in range(n): for rowid in range(n):
......
...@@ -31,7 +31,7 @@ class TestRandomCropOp(OpTest): ...@@ -31,7 +31,7 @@ class TestRandomCropOp(OpTest):
np.array([[6, 7, 8], [10, 11, 12]]).astype(np.int32) np.array([[6, 7, 8], [10, 11, 12]]).astype(np.int32)
] ]
self.op_type = "random_crop" self.op_type = "random_crop"
self.inputs = {'X': to_crop, 'Seed': np.array([10])} self.inputs = {'X': to_crop, 'Seed': np.array([10]).astype('int64')}
self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])} self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])}
self.attrs = {'shape': [2, 3]} self.attrs = {'shape': [2, 3]}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册