diff --git a/doc/fluid/design/dynamic_rnn/rnn.md b/doc/fluid/design/dynamic_rnn/rnn.md
index 2f4854793fa1f0b02e4dc17b51a48a972be61c06..6f414e5549b149bc88fb252085ff56dbb06730f8 100644
--- a/doc/fluid/design/dynamic_rnn/rnn.md
+++ b/doc/fluid/design/dynamic_rnn/rnn.md
@@ -5,7 +5,7 @@ This document describes the RNN (Recurrent Neural Network) operator and how it i
## RNN Algorithm Implementation
-
+
The above diagram shows an RNN unrolled into a full network.
@@ -22,7 +22,7 @@ There are several important concepts here:
There could be local variables defined in each step-net. PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step.
-
+
Figure 2 illustrates the RNN's data flow.
@@ -49,7 +49,7 @@ or copy the memory value of the previous step to the current ex-memory variable.
### Usage in Python
-For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md).
+For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/block.md).
We can define an RNN's step-net using a Block:
@@ -93,7 +93,7 @@ For example, we could have a 2-level RNN, where the top level corresponds to par
The following figure illustrates feeding text into the lower level, one sentence per step, and feeding the step outputs into the top level. The final top level output is about the whole text.
-
+
```python
@@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st
-
+
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index a688115b11af164319458207b19e915e8eaf676a..0b171e1dcfa90c3ad8f5a9ace8a9342baaf76e61 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -185,7 +185,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
                   std::map<std::string, const LoDTensor*>& feed_targets,
                   std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
-                   const std::string& fetch_holder_name) {
+                   const std::string& fetch_holder_name, bool create_vars) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
@@ -255,7 +255,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
- Run(*copy_program, scope, 0, true, true);
+ Run(*copy_program, scope, 0, create_vars, create_vars);
// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block->AllOps()) {
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index fb29c70f1456eca7b46e779f737976f5f2da0682..d8dd82469af06a4c5c6a37d2249ee23413884a91 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -54,7 +54,8 @@ class Executor {
           std::map<std::string, const LoDTensor*>& feed_targets,
           std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name = "feed",
- const std::string& fetch_holder_name = "fetch");
+ const std::string& fetch_holder_name = "fetch",
+ bool create_vars = true);
  static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
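
The new `create_vars` flag lets this `Run` overload skip variable creation when the scope has already been populated, instead of unconditionally passing `true` as before. A minimal Python-side sketch of the scenario the flag targets (the flag itself is only reachable from C++; the programs below are the standard fluid defaults):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# The startup program creates and initializes the variables in the scope once.
exe.run(fluid.default_startup_program())

# Later runs reuse the same scope; internally, the C++ Run() can now be called
# with create_vars=false because the variables already exist.
for _ in range(3):
    exe.run(fluid.default_main_program())
```
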
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 215ae229aff96d76fc948e19bdb42db319af65dc..5d27f5b60c7115a32aeeca5ec2a6654471c310c7 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
+
+ protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext &ctx) const override {
+ auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    // For float or float16 input tensors, the scale, bias, mean, and
+    // variance tensors should all be of float type.
+    auto bn_param_type = framework::proto::VarType::FP32;
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+                      "Scale input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+                      "Bias input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+                      "Mean input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(
+                          ctx.Input<Tensor>("Variance")->type()),
+                      "Variance input should be of float type");
+ return framework::OpKernelType(input_data_type, ctx.GetPlace());
+ }
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
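
As a NumPy sketch of the dtype contract enforced by `GetExpectedKernelType` above: the input `X` may be float or float16, while `Scale`, `Bias`, `Mean`, and `Variance` stay float32. Shapes are illustrative; the arithmetic mirrors the `_reference_testing` helper added to the unit test later in this diff:

```python
import numpy as np

n, c, h, w = 2, 3, 4, 5
epsilon = 1e-5
x = np.random.random_sample((n, c, h, w)).astype(np.float16)  # FP16 input
scale = np.random.random_sample(c).astype(np.float32)  # BN params stay FP32
bias = np.random.random_sample(c).astype(np.float32)
mean = np.zeros(c, dtype=np.float32)
variance = np.ones(c, dtype=np.float32)

# NCHW inference: normalize with the running statistics, then scale and shift.
normalized = (x - mean.reshape(1, c, 1, 1)) / np.sqrt(
    variance.reshape(1, c, 1, 1) + epsilon)
y = (normalized * scale.reshape(1, c, 1, 1) +
     bias.reshape(1, c, 1, 1)).astype(x.dtype)
```
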
diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc
index 2d1556efc66826ea9847de8311ccecdee0ea7871..6ceacc39924a7558e380aaf563aaf234f1bf30a5 100644
--- a/paddle/fluid/operators/batch_norm_op.cu.cc
+++ b/paddle/fluid/operators/batch_norm_op.cu.cc
@@ -18,6 +18,7 @@ limitations under the License. */
#include <cfloat>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
int *N, int *C, int *H, int *W, int *D) {
@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+ // Note: PERSISTENT not implemented for inference
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
- bn_param_desc_, data_desc_, mode_));
+ bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");
@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// alloc memory
    y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(dev_ctx, saved_mean, 0);
-    functor(dev_ctx, saved_variance, 0);
+    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+        functor;
+    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
auto handle = dev_ctx.cudnn_handle();
@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
          CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
          CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
          data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
-          est_mean->template data<T>(), est_var->template data<T>(), epsilon));
+          bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(),
+          est_mean->template data<BatchNormParamType<T>>(),
+          est_var->template data<BatchNormParamType<T>>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
              handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
              data_desc_, x->template data<T>(), data_desc_,
              y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-              scale->template data<T>(), bias->template data<T>(), this_factor,
-              mean_out->template mutable_data<T>(ctx.GetPlace()),
-              variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
-              saved_mean->template mutable_data<T>(ctx.GetPlace()),
-              saved_variance->template mutable_data<T>(ctx.GetPlace())));
+              scale->template data<BatchNormParamType<T>>(),
+              bias->template data<BatchNormParamType<T>>(), this_factor,
+              mean_out->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              variance_out->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+                           ctx.GetPlace()),
+              saved_variance->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace())));
}
// clean when exit.
@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} // namespace paddle
namespace ops = paddle::operators;
+namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
-    batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index 17e576a9d5c8f50fbe84b066a93460f03ae6bb08..299a0aed01dfe0448d896738d9fd33319b1b2887 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
cblas_daxpy(n, alpha, x, 1, y, 1);
}
+template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu
index c6ca2693a053360ce5dc44765acf1520a11cce2c..1e909db5288afccb9dd0be08a45cf3c27048ae6f 100644
--- a/paddle/fluid/operators/math/math_function.cu
+++ b/paddle/fluid/operators/math/math_function.cu
@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
&alpha, x, 1, y, 1));
}
+template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt
index 744bd3b7ef71f83ad82979eb966369c2e9456a7d..6fa0195b9ae103418beb56cc4b0fa9ab59e93108 100644
--- a/paddle/fluid/operators/reader/CMakeLists.txt
+++ b/paddle/fluid/operators/reader/CMakeLists.txt
@@ -15,10 +15,12 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE)
endfunction()
+reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
+reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
index bd0bb2ee3b0252f47318c59d9940d8dd478723de..76cdb794ccdb4a015ae8630940a5c26845e7a7b3 100644
--- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
+ if (!HasNext()) {
+ PADDLE_THROW("There is no next data!");
+ }
+
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
-
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
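
With `ReadNext` now raising on an exhausted reader, Python-side consumers are expected to guard the read loop with `eof()` and `reset()` the reader between epochs, as the new unit tests at the end of this diff do. A hedged sketch, assuming `reader`, `img`, and `exe` were set up via the layers in this diff:

```python
batch_count = 0
while not reader.eof():  # stop before ReadNext() would throw
    img_val, = exe.run(fetch_list=[img])
    batch_count += 1
reader.reset()  # re-initialize the reader before the next epoch
```
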
diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4d4e9fb909eafea5328491a4097276577f28a5ba
--- /dev/null
+++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class MultiPassReader : public framework::DecoratedReader {
+ public:
+ MultiPassReader(ReaderBase* reader, int pass_num)
+ : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
+
+  void ReadNext(std::vector<framework::LoDTensor>* out) override {
+ if (!HasNext()) {
+ PADDLE_THROW("There is no next data!");
+ }
+ reader_->ReadNext(out);
+ }
+
+ bool HasNext() const override {
+ if (reader_->HasNext()) {
+ return true;
+ } else {
+ ++pass_count_;
+ if (pass_count_ >= pass_num_) {
+ return false;
+ } else {
+ reader_->ReInit();
+ return true;
+ }
+ }
+ }
+
+ void ReInit() override {
+ pass_count_ = 0;
+ reader_->ReInit();
+ }
+
+ private:
+ int pass_num_;
+ mutable int pass_count_;
+};
+
+class CreateMultiPassReaderOp : public framework::OperatorBase {
+ public:
+ using framework::OperatorBase::OperatorBase;
+
+ private:
+ void RunImpl(const framework::Scope& scope,
+ const platform::Place& dev_place) const override {
+ const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
+                                        ->Get<framework::ReaderHolder>();
+    auto& out = detail::Ref(scope.FindVar(Output("Out")));
+    int pass_num = Attr<int>("pass_num");
+    out.GetMutable<framework::ReaderHolder>()->Reset(
+ new MultiPassReader(underlying_reader.Get(), pass_num));
+ }
+};
+
+class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
+ public:
+ CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+ : DecoratedReaderMakerBase(op_proto, op_checker) {
+    AddAttr<int>("pass_num", "The number of passes to run.").GreaterThan(0);
+ AddComment(R"DOC(
+ CreateMultiPassReader Operator
+
+      This operator creates a multi-pass reader. A multi-pass reader
+      is used to continuously yield data for several training passes.
+      It takes the number of passes to run as one of its attributes
+      ('pass_num'), and maintains a pass counter to record how many
+      passes it has completed. When the underlying reader reaches the
+      EOF, the multi-pass reader checks whether it has completed the
+      given number of passes. If not, the underlying reader will be
+      re-initialized and start a new pass automatically.
+ )DOC");
+ }
+};
+
+} // namespace reader
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators::reader;
+REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
+ ops::CreateMultiPassReaderOp,
+ ops::CreateMultiPassReaderOpMaker);
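
The decorator is exposed as `fluid.layers.create_multi_pass_reader` later in this diff; usage mirrors the new `test_multi_pass_reader.py`:

```python
import paddle.fluid as fluid

data_file = fluid.layers.open_recordio_file(
    filename='./mnist.recordio',
    shapes=[(-1, 784), (-1, 1)],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'])
# Reads from this reader now transparently span pass_num passes over the data.
data_file = fluid.layers.create_multi_pass_reader(reader=data_file, pass_num=3)
img, label = fluid.layers.read_file(data_file)
```
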
diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..414c76fea0bb916dfeafe38c0448a7a800889e03
--- /dev/null
+++ b/paddle/fluid/operators/reader/open_files_op.cc
@@ -0,0 +1,212 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class MultipleReader : public framework::ReaderBase {
+ public:
+  MultipleReader(const std::vector<std::string>& file_names,
+                 const std::vector<framework::DDim>& dims, size_t thread_num)
+ : file_names_(file_names), dims_(dims) {
+ prefetchers_.resize(thread_num);
+ StartNewScheduler();
+ }
+
+  void ReadNext(std::vector<framework::LoDTensor>* out) override;
+ bool HasNext() const override;
+ void ReInit() override;
+
+ ~MultipleReader() { EndScheduler(); }
+
+ private:
+ void StartNewScheduler();
+ void EndScheduler();
+ void ScheduleThreadFunc();
+ void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
+
+  std::vector<std::string> file_names_;
+  std::vector<framework::DDim> dims_;
+  std::thread scheduler_;
+  std::vector<std::thread> prefetchers_;
+  framework::Channel<size_t>* waiting_file_idx_;
+  framework::Channel<size_t>* available_thread_idx_;
+  framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
+  mutable std::vector<framework::LoDTensor> local_buffer_;
+};
+
+void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
+ if (!HasNext()) {
+ PADDLE_THROW("There is no next data!");
+ }
+
+ if (local_buffer_.empty()) {
+ buffer_->Receive(&local_buffer_);
+ }
+ *out = local_buffer_;
+ local_buffer_.clear();
+}
+
+bool MultipleReader::HasNext() const {
+ return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
+}
+
+void MultipleReader::ReInit() {
+ EndScheduler();
+ local_buffer_.clear();
+ StartNewScheduler();
+}
+
+void MultipleReader::StartNewScheduler() {
+ size_t thread_num = prefetchers_.size();
+  waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
+  available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
+  buffer_ =
+      framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
+
+ for (size_t i = 0; i < file_names_.size(); ++i) {
+ waiting_file_idx_->Send(&i);
+ }
+ waiting_file_idx_->Close();
+ for (size_t i = 0; i < thread_num; ++i) {
+ available_thread_idx_->Send(&i);
+ }
+
+ scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
+}
+
+void MultipleReader::EndScheduler() {
+ available_thread_idx_->Close();
+ buffer_->Close();
+ waiting_file_idx_->Close();
+ if (scheduler_.joinable()) {
+ scheduler_.join();
+ }
+ delete buffer_;
+ delete available_thread_idx_;
+ delete waiting_file_idx_;
+}
+
+void MultipleReader::ScheduleThreadFunc() {
+ VLOG(5) << "MultipleReader schedule thread starts.";
+ size_t completed_thread_num = 0;
+ size_t thread_idx;
+ while (available_thread_idx_->Receive(&thread_idx)) {
+ std::thread& prefetcher = prefetchers_[thread_idx];
+ if (prefetcher.joinable()) {
+ prefetcher.join();
+ }
+ size_t file_idx;
+ if (waiting_file_idx_->Receive(&file_idx)) {
+ // Still have files to read. Start a new prefetch thread.
+ std::string file_name = file_names_[file_idx];
+ prefetcher = std::thread([this, file_name, thread_idx] {
+ PrefetchThreadFunc(file_name, thread_idx);
+ });
+ } else {
+ // No more file to read.
+ ++completed_thread_num;
+ if (completed_thread_num == prefetchers_.size()) {
+ buffer_->Close();
+ break;
+ }
+ }
+ }
+  // If users invoke ReInit() while the scheduler is running, it will close
+  // 'available_thread_idx_' and the prefetcher threads will have no way to
+  // tell the scheduler to release their resources. So a check is needed
+  // before the scheduler ends.
+ for (auto& p : prefetchers_) {
+ if (p.joinable()) {
+ p.join();
+ }
+ }
+ VLOG(5) << "MultipleReader schedule thread terminates.";
+}
+
+void MultipleReader::PrefetchThreadFunc(std::string file_name,
+ size_t thread_idx) {
+ VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
+  std::unique_ptr<framework::ReaderBase> reader =
+      CreateReaderByFileName(file_name, dims_);
+ while (reader->HasNext()) {
+    std::vector<framework::LoDTensor> ins;
+ reader->ReadNext(&ins);
+ if (!buffer_->Send(&ins)) {
+ VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
+ "thread of file '"
+ << file_name << "' will terminate.";
+ break;
+ }
+ }
+ if (!available_thread_idx_->Send(&thread_idx)) {
+ VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
+ "Fail to send thread_idx.";
+ }
+ VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
+}
+
+class OpenFilesOp : public framework::OperatorBase {
+ public:
+ using framework::OperatorBase::OperatorBase;
+
+ private:
+ void RunImpl(const framework::Scope& scope,
+ const platform::Place& dev_place) const override {
+    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
+    const auto& ranks = Attr<std::vector<int>>("ranks");
+    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
+    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
+                      int(shape_concat.size()),
+                      "The sum of all ranks should be equal to the length of "
+                      "'shape_concat'.");
+    const auto& file_names = Attr<std::vector<std::string>>("file_names");
+    PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
+    const size_t thread_num = Attr<int>("thread_num");
+
+    auto* out = scope.FindVar(Output("Out"))
+                    ->template GetMutable<framework::ReaderHolder>();
+ out->Reset(new MultipleReader(
+ file_names, RestoreShapes(shape_concat, ranks), thread_num));
+ }
+};
+
+class OpenFilesOpMaker : public FileReaderMakerBase {
+ public:
+ OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+ : FileReaderMakerBase(op_proto, op_checker) {
+    AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
+    AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
+ .GreaterThan(0);
+
+ AddComment(R"DOC(
+ OpenFiles Operator
+
+      An OpenFilesOp creates a MultipleReader, which reads data
+      from multiple files concurrently using multiple prefetch threads.
+ )DOC");
+ }
+};
+
+} // namespace reader
+} // namespace operators
+} // namespace paddle
+
+namespace reader = paddle::operators::reader;
+
+REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
+ reader::OpenFilesOpMaker);
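
From Python, this op is exposed as `fluid.layers.open_files` (added later in this diff); usage mirrors the new `test_multiple_reader.py`:

```python
import paddle.fluid as fluid

data_files = fluid.layers.open_files(
    filenames=['./mnist_0.recordio', './mnist_1.recordio',
               './mnist_2.recordio'],
    thread_num=2,  # prefetch threads; may be fewer or more than the files
    shapes=[(-1, 784), (-1, 1)],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_files)
```
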
diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc
index 0ba4f3854431742eb354f8c90eb395f5d7b32b2e..fc8dc747ff0c2286f4516d8350f75d9887361924 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.cc
+++ b/paddle/fluid/operators/reader/reader_op_registry.cc
@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs;
}
+std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
+    const std::string& file_name, const std::vector<framework::DDim>& dims) {
+ size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
+ PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
+ "File name illegal! A legal file name should be like: "
+ "[file_name].[file_format] (e.g., 'data_file.recordio').");
+ std::string filetype = file_name.substr(separator_pos + 1);
+
+ auto itor = FileReaderRegistry().find(filetype);
+ PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
+ "No file reader registered for '%s' format.", filetype);
+ framework::ReaderBase* reader = (itor->second)(file_name, dims);
+  return std::unique_ptr<framework::ReaderBase>(reader);
+}
+
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
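
For illustration only, the dispatch in `CreateReaderByFileName` expressed in Python: the suffix after the last `.` selects the registered reader creator. `filetype_of` is a hypothetical helper, not part of the codebase:

```python
def filetype_of(file_name):
    # Mirrors CreateReaderByFileName: split on the last '.' and use the
    # extension as the registry key.
    prefix, sep, ext = file_name.rpartition('.')
    if not sep:
        raise ValueError("A legal file name should be like "
                         "[file_name].[file_format], e.g. 'data_file.recordio'")
    return ext

assert filetype_of('data_file.recordio') == 'recordio'
```
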
diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h
index 58f9b4ba35546571fd3b1d0c3ce128f18e248f01..929d32ad8b367865e33530f8517343c513ee9878 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.h
+++ b/paddle/fluid/operators/reader/reader_op_registry.h
@@ -21,6 +21,8 @@ namespace paddle {
namespace operators {
namespace reader {
+static constexpr char kFileFormatSeparator[] = ".";
+
using FileReaderCreator = std::function<framework::ReaderBase*(
    const std::string&, const std::vector<framework::DDim>&)>;
@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
-      const std::string& fn, const std::vector<framework::DDim>& dim) {
-    return new Reader(fn, dim);
+      const std::string& fn, const std::vector<framework::DDim>& dims) {
+    return new Reader(fn, dims);
};
return 0;
}
+std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
+    const std::string& file_name, const std::vector<framework::DDim>& dims);
+
extern std::vector<framework::DDim> RestoreShapes(
    const std::vector<int>& shape_concat, const std::vector<int>& ranks);
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 7e001ecc56173db76e8c576e7efd66f41192f292..7c604e14eb245232ed92f53a00b9bde45c2fbaec 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
// The scaling param type is float for HALF and FLOAT tensors
- typedef const float ScalingParamType;
+ using ScalingParamType = const float;
+ using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
@@ -101,7 +102,8 @@ template <>
class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
- typedef const float ScalingParamType;
+ using ScalingParamType = const float;
+ using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
@@ -116,7 +118,8 @@ template <>
class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
- typedef const double ScalingParamType;
+ using ScalingParamType = const double;
+ using BatchNormParamType = double;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
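
The net effect of the new `BatchNormParamType`, shown as an illustrative mapping (not code from the repo): HALF tensors keep full-precision batch-norm parameters, which is what cuDNN's batch-norm routines require:

```python
# tensor dtype -> dtype of Scale/Bias/Mean/Variance passed to cudnnBatchNorm*
BATCH_NORM_PARAM_TYPE = {
    'float16': 'float32',  # CUDNN_DATA_HALF still uses FP32 BN parameters
    'float32': 'float32',
    'float64': 'float64',
}
```
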
diff --git a/paddle/fluid/recordio/header.cc b/paddle/fluid/recordio/header.cc
index e50de15b7c2b480357f5f6c7daa2b4a676749679..ed09d58f6a3e2dba50bf4407c0463480575b248e 100644
--- a/paddle/fluid/recordio/header.cc
+++ b/paddle/fluid/recordio/header.cc
@@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
bool Header::Parse(std::istream& is) {
uint32_t magic;
-  size_t read_size =
-      is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
+  is.read(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
+  size_t read_size = is.gcount();
if (read_size < sizeof(uint32_t)) {
return false;
}
diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc
index d842f8fe5a4c9d1a2b564c738d97fffb02f3ccb5..c22281dc97e05173ad76ce76959833b92f11c4ee 100644
--- a/paddle/fluid/recordio/scanner.cc
+++ b/paddle/fluid/recordio/scanner.cc
@@ -28,6 +28,7 @@ Scanner::Scanner(const std::string &filename) {
}
void Scanner::Reset() {
+ stream_->clear();
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
}
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 9c91f395e7c9d7ca76c1a5cc310bc3bbc06daec9..bc5e291ad811315ddc9d101853d69c7f5ab5082d 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -21,7 +21,8 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
- 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
+ 'open_files', 'read_file', 'create_shuffle_reader',
+ 'create_double_buffer_reader', 'create_multi_pass_reader'
]
@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)
+def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
+ dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
+ shape_concat = []
+ ranks = []
+
+ for shape in shapes:
+ shape_concat.extend(shape)
+ ranks.append(len(shape))
+
+ var_name = unique_name('multiple_reader')
+
+ startup_blk = default_startup_program().current_block()
+ startup_var = startup_blk.create_var(name=var_name)
+ startup_blk.append_op(
+ type='open_files',
+ outputs={'Out': [startup_var]},
+ attrs={
+ 'shape_concat': shape_concat,
+ 'lod_levels': lod_levels,
+ 'ranks': ranks,
+ 'file_names': filenames,
+ 'thread_num': thread_num
+ })
+
+ startup_var.desc.set_dtypes(dtypes)
+ startup_var.persistable = True
+ return _copy_reader_var_(default_main_program().current_block(),
+ startup_var)
+
+
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
@@ -314,6 +345,11 @@ def create_double_buffer_reader(reader, place=None):
attrs)
+def create_multi_pass_reader(reader, pass_num):
+ return __create_decorated_reader__('create_multi_pass_reader', reader,
+ {'pass_num': int(pass_num)})
+
+
def read_file(file_obj):
helper = LayerHelper('read_file')
out = [
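
Since every decorated reader goes through the same `__create_decorated_reader__` helper, the decorators compose; a hedged sketch chaining the layers above (the file names are hypothetical, and this ordering is one reasonable choice rather than something this diff mandates):

```python
import paddle.fluid as fluid

reader = fluid.layers.open_files(
    filenames=['./part_0.recordio', './part_1.recordio'],  # hypothetical shards
    thread_num=2,
    shapes=[(-1, 784), (-1, 1)],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'])
reader = fluid.layers.create_multi_pass_reader(reader, pass_num=3)
reader = fluid.layers.create_double_buffer_reader(reader)
img, label = fluid.layers.read_file(reader)
```
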
diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore
index 6b3fc2a83c649c28d21c9a8a0b35c2f2fa04f269..ad02bdecf436bba925e2e3b7efb20c878df70dfd 100644
--- a/python/paddle/fluid/tests/unittests/.gitignore
+++ b/python/paddle/fluid/tests/unittests/.gitignore
@@ -1 +1,4 @@
mnist.recordio
+mnist_0.recordio
+mnist_1.recordio
+mnist_2.recordio
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
index 80e6fa6df3c21aa19feb571916f11c41ccd6bb10..10aa63e18a6eeaa44e5b12f7532998dca2bc5e9f 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op
+def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
+ x_shape = x.shape
+ if len(x_shape) == 2:
+ if data_format == "NCHW":
+ x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
+ else:
+ x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+
+ if data_format == "NCHW":
+ n, c, h, w = x.shape
+ mean_tile = np.reshape(mean, (1, c, 1, 1))
+ mean_tile = np.tile(mean_tile, (n, 1, h, w))
+ var_tile = np.reshape(var, (1, c, 1, 1))
+ var_tile = np.tile(var_tile, (n, 1, h, w))
+ normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
+ scale_tile = np.reshape(scale, (1, c, 1, 1))
+ scale_tile = np.tile(scale_tile, (n, 1, h, w))
+ offset_tile = np.reshape(offset, (1, c, 1, 1))
+        offset_tile = np.tile(offset_tile, (n, 1, h, w))
+ y = normalized * scale_tile + offset_tile
+ elif data_format == "NHWC":
+ normalized = (x - mean) / np.sqrt(var + epsilon)
+ y = normalized * scale + offset
+ else:
+ raise ValueError("Unknown data order.")
+
+ if len(x_shape) == 2:
+ y = np.reshape(y, x_shape)
+ return y
+
+
def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 2:
@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
__set_tensor__(output, data)
-class TestBatchNormOp(OpTest):
+class TestBatchNormOpInference(OpTest):
+ def setUp(self):
+ self.dtype = np.float32
+
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
- def test_python(self):
+ def check_with_place(self, place, data_layout, dtype, shape):
+ epsilon = 0.00001
+ if len(shape) == 2:
+ x_shape = shape
+ c = x_shape[1]
+ else:
+ n, h, w, c = shape[0], shape[1], shape[2], shape[3]
+ if data_layout == "NHWC":
+ x_shape = [n, h, w, c]
+ elif data_layout == "NCHW":
+ x_shape = [n, c, h, w]
+ else:
+ raise ValueError("Unknown data layout.")
+ scale_shape = [c]
+
+ x_val = np.random.random_sample(x_shape).astype(dtype)
+ scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+ bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+ mean = np.zeros(scale_shape).astype(np.float32)
+ variance = np.ones(scale_shape).astype(np.float32)
+
+ y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+ epsilon, data_layout).astype(dtype)
+
+ scope = core.Scope()
+
+ # create input
+ x_tensor = create_or_get_tensor(scope, "x_val",
+ OpTest.np_dtype_to_fluid_dtype(x_val),
+ place)
+ scale_tensor = create_or_get_tensor(
+ scope, "scale_val",
+ OpTest.np_dtype_to_fluid_dtype(scale_val), place)
+ bias_tensor = create_or_get_tensor(
+ scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
+ mean_tensor = create_or_get_tensor(scope, "mean",
+ OpTest.np_dtype_to_fluid_dtype(mean),
+ place)
+ variance_tensor = create_or_get_tensor(
+ scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)
+
+ # create output
+ y_tensor = create_or_get_tensor(scope, "y_out", None, place)
+ saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
+ place)
+ saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
+ None, place)
+ mean_out_tensor = mean_tensor
+ variance_out_tensor = variance_tensor
+
+ batch_norm_op = Operator(
+ "batch_norm",
+ # inputs
+ X="x_val",
+ Scale="scale_val",
+ Bias="bias_val",
+ Mean="mean",
+ Variance="variance",
+ # outputs
+ Y="y_out",
+ MeanOut="mean",
+ VarianceOut="variance",
+ SavedMean="saved_mean",
+ SavedVariance="saved_variance",
+ # attrs
+ is_test=True,
+ data_layout=data_layout,
+ epsilon=epsilon)
+
+ batch_norm_op.run(scope, place)
+
+ # check inference result
+ self.__assert_close(
+ y_tensor,
+ y_out,
+ "inference output are different at " + str(place) + ", " +
+ data_layout + ", " + str(np.dtype(dtype)) +
+ str(np.array(y_tensor)) + str(y_out),
+ atol=1e-3)
+
+ def test_check_output(self):
+ places = [core.CPUPlace()]
+ if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+ places.append(core.CUDAPlace(0))
+
+ for place in places:
+ for data_format in ["NCHW", "NHWC"]:
+ self.check_with_place(place, data_format, self.dtype,
+ [2, 3, 4, 5])
+ self.check_with_place(place, data_format, self.dtype, [2, 3])
+
+
+class TestFP16BatchNormOpInference(TestBatchNormOpInference):
+ def setUp(self):
+ self.dtype = np.float16
+
+ def test_check_output(self):
+ places = []
+ if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+ place = core.CUDAPlace(0)
+ if core.is_float16_supported(place):
+ places.append(place)
+
+ for place in places:
+ for data_format in ["NCHW", "NHWC"]:
+ self.check_with_place(place, data_format, self.dtype,
+ [2, 3, 4, 5])
+ self.check_with_place(place, data_format, self.dtype, [2, 3])
+
+
+class TestBatchNormOpTraining(OpTest):
+ def __assert_close(self, tensor, np_array, msg, atol=1e-4):
+ self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
+
+ def test_python_testing(self):
+ data_format = "NHWC"
+ epsilon = 0.00001
+
+ n, h, w, c = 2, 3, 4, 5
+ x_shape = [n, h, w, c]
+ scale_shape = [c]
+
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+ bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+ mean = np.zeros(scale_shape).astype(np.float32)
+ variance = np.ones(scale_shape).astype(np.float32)
+
+ y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+ epsilon, "NHWC")
+
+ # running N, C, H, W case
+ # should produce the same results
+ x_shape2 = [n, c, h, w]
+ x_val2 = np.transpose(x_val, (0, 3, 1, 2))
+ y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
+ epsilon, "NCHW")
+
+ # transfer (N, C, H, W) back to (N, H, W, C)
+ y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
+ self.__assert_close(y_out, y_out2_trans, "inference output")
+ print 'python: NHWC, NCHW, inference checking passed'
+
+ def test_python_training(self):
data_format = "NHWC"
epsilon = 0.00001
momentum = 0.9
@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
- self.__assert_close(y_out, y_out2_trans, "batch variance")
+ self.__assert_close(y_out, y_out2_trans, "batch output")
print 'python: NHWC, NCHW, forward checking passed'
# test backward now
diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..8add353303e3626bbce68199a100306d4858766a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+import paddle.v2.dataset.mnist as mnist
+
+
+class TestMultiPassReader(unittest.TestCase):
+ def setUp(self):
+ self.batch_size = 64
+ self.pass_num = 3
+ # Convert mnist to recordio file
+ with fluid.program_guard(fluid.Program(), fluid.Program()):
+ data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
+ feeder = fluid.DataFeeder(
+ feed_list=[
+ fluid.layers.data(
+ name='image', shape=[784]),
+ fluid.layers.data(
+ name='label', shape=[1], dtype='int64'),
+ ],
+ place=fluid.CPUPlace())
+ self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
+ './mnist.recordio', data_file, feeder)
+
+ def test_main(self):
+ with fluid.program_guard(fluid.Program(), fluid.Program()):
+ data_file = fluid.layers.open_recordio_file(
+ filename='./mnist.recordio',
+ shapes=[(-1, 784), (-1, 1)],
+ lod_levels=[0, 0],
+ dtypes=['float32', 'int64'])
+ data_file = fluid.layers.create_multi_pass_reader(
+ reader=data_file, pass_num=self.pass_num)
+ img, label = fluid.layers.read_file(data_file)
+
+ if fluid.core.is_compiled_with_cuda():
+ place = fluid.CUDAPlace(0)
+ else:
+ place = fluid.CPUPlace()
+
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+ batch_count = 0
+ while not data_file.eof():
+ img_val, = exe.run(fetch_list=[img])
+ batch_count += 1
+ self.assertLessEqual(img_val.shape[0], self.batch_size)
+ data_file.reset()
+ self.assertEqual(batch_count, self.num_batch * self.pass_num)
diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f8acf81efaba8fc0f3df4cfe3a42dc4e477df2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+import paddle.v2.dataset.mnist as mnist
+from shutil import copyfile
+
+
+class TestMultipleReader(unittest.TestCase):
+ def setUp(self):
+ self.batch_size = 64
+ # Convert mnist to recordio file
+ with fluid.program_guard(fluid.Program(), fluid.Program()):
+ reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
+ feeder = fluid.DataFeeder(
+ feed_list=[ # order is image and label
+ fluid.layers.data(
+ name='image', shape=[784]),
+ fluid.layers.data(
+ name='label', shape=[1], dtype='int64'),
+ ],
+ place=fluid.CPUPlace())
+ self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
+ './mnist_0.recordio', reader, feeder)
+ copyfile('./mnist_0.recordio', './mnist_1.recordio')
+ copyfile('./mnist_0.recordio', './mnist_2.recordio')
+
+ def main(self, thread_num):
+ file_list = [
+ './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
+ ]
+ with fluid.program_guard(fluid.Program(), fluid.Program()):
+ data_files = fluid.layers.open_files(
+ filenames=file_list,
+ thread_num=thread_num,
+ shapes=[(-1, 784), (-1, 1)],
+ lod_levels=[0, 0],
+ dtypes=['float32', 'int64'])
+ img, label = fluid.layers.read_file(data_files)
+
+ if fluid.core.is_compiled_with_cuda():
+ place = fluid.CUDAPlace(0)
+ else:
+ place = fluid.CPUPlace()
+
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+ batch_count = 0
+ while not data_files.eof():
+ img_val, = exe.run(fetch_list=[img])
+ batch_count += 1
+ self.assertLessEqual(img_val.shape[0], self.batch_size)
+ data_files.reset()
+ self.assertEqual(batch_count, self.num_batch * 3)
+
+ def test_main(self):
+ self.main(thread_num=3) # thread number equals to file number
+ self.main(thread_num=10) # thread number is larger than file number
+ self.main(thread_num=2) # thread number is less than file number