diff --git a/doc/fluid/design/dynamic_rnn/rnn.md b/doc/fluid/design/dynamic_rnn/rnn.md
index 2f4854793fa1f0b02e4dc17b51a48a972be61c06..6f414e5549b149bc88fb252085ff56dbb06730f8 100644
--- a/doc/fluid/design/dynamic_rnn/rnn.md
+++ b/doc/fluid/design/dynamic_rnn/rnn.md
@@ -5,7 +5,7 @@ This document describes the RNN (Recurrent Neural Network) operator and how it i
 ## RNN Algorithm Implementation
 
 
- +
+ 
 
 
 The above diagram shows an RNN unrolled into a full network.
@@ -22,7 +22,7 @@ There are several important concepts here:
 There could be local variables defined in each step-net.  PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step.
 
 
-
+
 Figure 2 illustrates the RNN's data flow
 
 
@@ -49,7 +49,7 @@ or copy the memory value of the previous step to the current ex-memory variable.
 
 ### Usage in Python
 
-For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md).
+For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/block.md).
 
 We can define an RNN's step-net using a Block:
 
@@ -93,7 +93,7 @@ For example, we could have a 2-level RNN, where the top level corresponds to par
 The following figure illustrates feeding in text into the lower level, one sentence at a step, and the feeding in step outputs to the top level. The final top level output is about the whole text.
 
 
- +
+ 
 
 
 ```python
@@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st
 
 
 
- +
+ 
 
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index a688115b11af164319458207b19e915e8eaf676a..0b171e1dcfa90c3ad8f5a9ace8a9342baaf76e61 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -185,7 +185,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
                    std::map& feed_targets,
                    std::map& fetch_targets,
                    const std::string& feed_holder_name,
-                   const std::string& fetch_holder_name) {
+                   const std::string& fetch_holder_name, bool create_vars) {
   platform::RecordBlock b(kProgramId);
   bool has_feed_ops =
       has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
@@ -255,7 +255,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }
 
-  Run(*copy_program, scope, 0, true, true);
+  Run(*copy_program, scope, 0, create_vars, create_vars);
 
   // obtain the data of fetch_targets from fetch_holder
   for (auto* op : global_block->AllOps()) {
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index fb29c70f1456eca7b46e779f737976f5f2da0682..d8dd82469af06a4c5c6a37d2249ee23413884a91 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -54,7 +54,8 @@ class Executor {
            std::map& feed_targets,
            std::map& fetch_targets,
            const std::string& feed_holder_name = "feed",
-           const std::string& fetch_holder_name = "fetch");
+           const std::string& fetch_holder_name = "fetch",
+           bool create_vars = true);
 
   static std::unique_ptr Prepare(
       const ProgramDesc& program, int block_id);
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 215ae229aff96d76fc948e19bdb42db319af65dc..5d27f5b60c7115a32aeeca5ec2a6654471c310c7 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SavedVariance", {C});
     ctx->ShareLoD("X", "Y");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input("X")->type());
+    // For float or float16 input tensor, the type of the scale, bias, mean,
+    // and var tensors should both be float.
+    auto bn_param_type = framework::proto::VarType::FP32;
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input("Scale")->type()),
+                      "Scale input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input("Bias")->type()),
+                      "Bias input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input("Mean")->type()),
+                      "Mean input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
+                                         ctx.Input("Variance")->type()),
+                      "Variance input should be of float type");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc
index 2d1556efc66826ea9847de8311ccecdee0ea7871..6ceacc39924a7558e380aaf563aaf234f1bf30a5 100644
--- a/paddle/fluid/operators/batch_norm_op.cu.cc
+++ b/paddle/fluid/operators/batch_norm_op.cu.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include 
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template 
 using CudnnDataType = platform::CudnnDataType;
+template 
+using BatchNormParamType = typename CudnnDataType::BatchNormParamType;
 
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -104,8 +107,9 @@ class BatchNormKernel
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+    // Note: PERSISTENT not implemented for inference
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
-        bn_param_desc_, data_desc_, mode_));
+        bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
 
     const auto *scale = ctx.Input("Scale");
     const auto *bias = ctx.Input("Bias");
@@ -118,15 +122,16 @@ class BatchNormKernel
 
     // alloc memory
     y->mutable_data(ctx.GetPlace());
-    mean_out->mutable_data(ctx.GetPlace());
-    variance_out->mutable_data(ctx.GetPlace());
-    saved_mean->mutable_data(ctx.GetPlace());
-    saved_variance->mutable_data(ctx.GetPlace());
+    mean_out->mutable_data>(ctx.GetPlace());
+    variance_out->mutable_data>(ctx.GetPlace());
+    saved_mean->mutable_data>(ctx.GetPlace());
+    saved_variance->mutable_data>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context();
-    math::SetConstant functor;
-    functor(dev_ctx, saved_mean, 0);
-    functor(dev_ctx, saved_variance, 0);
+    math::SetConstant>
+        functor;
+    functor(dev_ctx, saved_mean, static_cast>(0));
+    functor(dev_ctx, saved_variance, static_cast>(0));
 
     auto handle = dev_ctx.cudnn_handle();
 
@@ -147,8 +152,10 @@ class BatchNormKernel
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(),
           CudnnDataType::kZero(), data_desc_, x->template data(),
           data_desc_, y->template mutable_data(ctx.GetPlace()),
-          bn_param_desc_, scale->template data(), bias->template data(),
-          est_mean->template data(), est_var->template data(), epsilon));
+          bn_param_desc_, scale->template data>(),
+          bias->template data>(),
+          est_mean->template data>(),
+          est_var->template data>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -159,11 +166,16 @@ class BatchNormKernel
           handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(),
           data_desc_, x->template data(), data_desc_,
           y->template mutable_data(ctx.GetPlace()), bn_param_desc_,
-          scale->template data(), bias->template data(), this_factor,
-          mean_out->template mutable_data(ctx.GetPlace()),
-          variance_out->template mutable_data(ctx.GetPlace()), epsilon,
-          saved_mean->template mutable_data(ctx.GetPlace()),
-          saved_variance->template mutable_data(ctx.GetPlace())));
+          scale->template data>(),
+          bias->template data>(), this_factor,
+          mean_out->template mutable_data>(
+              ctx.GetPlace()),
+          variance_out->template mutable_data>(
+              ctx.GetPlace()),
+          epsilon, saved_mean->template mutable_data>(
+                       ctx.GetPlace()),
+          saved_variance->template mutable_data>(
+              ctx.GetPlace())));
     }
 
     // clean when exit.
@@ -270,9 +282,9 @@ class BatchNormGradKernel
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel);
+    batch_norm, ops::BatchNormKernel,
+    ops::BatchNormKernel);
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm_grad,
-    ops::BatchNormGradKernel);
+    batch_norm_grad, ops::BatchNormGradKernel);
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index 17e576a9d5c8f50fbe84b066a93460f03ae6bb08..299a0aed01dfe0448d896738d9fd33319b1b2887 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -278,6 +278,7 @@ void axpy(
   cblas_daxpy(n, alpha, x, 1, y, 1);
 }
 
+template struct SetConstant;
 template struct SetConstant;
 template struct SetConstant;
 template struct SetConstant;
diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu
index c6ca2693a053360ce5dc44765acf1520a11cce2c..1e909db5288afccb9dd0be08a45cf3c27048ae6f 100644
--- a/paddle/fluid/operators/math/math_function.cu
+++ b/paddle/fluid/operators/math/math_function.cu
@@ -348,6 +348,7 @@ void axpy(
                                                 &alpha, x, 1, y, 1));
 }
 
+template struct SetConstant;
 template struct SetConstant;
 template struct SetConstant;
 template struct SetConstant;
diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt
index 744bd3b7ef71f83ad82979eb966369c2e9456a7d..6fa0195b9ae103418beb56cc4b0fa9ab59e93108 100644
--- a/paddle/fluid/operators/reader/CMakeLists.txt
+++ b/paddle/fluid/operators/reader/CMakeLists.txt
@@ -15,10 +15,12 @@ function(reader_library TARGET_NAME)
         PARENT_SCOPE)
 endfunction()
 
+reader_library(open_files_op SRCS open_files_op.cc)
 reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
 reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
 reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
 reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
 reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
+reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
 # Export local libraries to parent
 set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
index bd0bb2ee3b0252f47318c59d9940d8dd478723de..76cdb794ccdb4a015ae8630940a5c26845e7a7b3 100644
--- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 };
 
 void DoubleBufferReader::ReadNext(std::vector* out) {
+  if (!HasNext()) {
+    PADDLE_THROW("There is no next data!");
+  }
+
   if (local_buffer_.payloads_.empty()) {
     buffer_->Receive(&local_buffer_);
   }
-
   *out = local_buffer_.payloads_;
   local_buffer_.payloads_.clear();
   if (local_buffer_.ctx_) {
diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4d4e9fb909eafea5328491a4097276577f28a5ba
--- /dev/null
+++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
@@ -0,0 +1,101 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class MultiPassReader : public framework::DecoratedReader {
+ public:
+  MultiPassReader(ReaderBase* reader, int pass_num)
+      : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
+
+  void ReadNext(std::vector* out) override {
+    if (!HasNext()) {
+      PADDLE_THROW("There is no next data!");
+    }
+    reader_->ReadNext(out);
+  }
+
+  bool HasNext() const override {
+    if (reader_->HasNext()) {
+      return true;
+    } else {
+      ++pass_count_;
+      if (pass_count_ >= pass_num_) {
+        return false;
+      } else {
+        reader_->ReInit();
+        return true;
+      }
+    }
+  }
+
+  void ReInit() override {
+    pass_count_ = 0;
+    reader_->ReInit();
+  }
+
+ private:
+  int pass_num_;
+  mutable int pass_count_;
+};
+
+class CreateMultiPassReaderOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
+                                        ->Get();
+    auto& out = detail::Ref(scope.FindVar(Output("Out")));
+    int pass_num = Attr("pass_num");
+    out.GetMutable()->Reset(
+        new MultiPassReader(underlying_reader.Get(), pass_num));
+  }
+};
+
+class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
+ public:
+  CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+      : DecoratedReaderMakerBase(op_proto, op_checker) {
+    AddAttr("pass_num", "The number of pass to run.").GreaterThan(0);
+    AddComment(R"DOC(
+      CreateMultiPassReader Operator
+
+      This operator creates a multi-pass reader. A multi-pass reader 
+      is used to yield data for several pass training continuously. 
+      It takes the the number of pass to run as one of its attributes
+      ('pass_num'), and maintains a pass counter to record how many 
+      passes it has completed. When the underlying reader reach the EOF, 
+      the multi-pass reader checks whether it has completed training 
+      of the given number of pass. If not, the underlying reader will 
+      be re-initialized and starts a new pass automatically.
+    )DOC");
+  }
+};
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators::reader;
+REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
+                                   ops::CreateMultiPassReaderOp,
+                                   ops::CreateMultiPassReaderOpMaker);
diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..414c76fea0bb916dfeafe38c0448a7a800889e03
--- /dev/null
+++ b/paddle/fluid/operators/reader/open_files_op.cc
@@ -0,0 +1,212 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class MultipleReader : public framework::ReaderBase {
+ public:
+  MultipleReader(const std::vector& file_names,
+                 const std::vector& dims, size_t thread_num)
+      : file_names_(file_names), dims_(dims) {
+    prefetchers_.resize(thread_num);
+    StartNewScheduler();
+  }
+
+  void ReadNext(std::vector* out) override;
+  bool HasNext() const override;
+  void ReInit() override;
+
+  ~MultipleReader() { EndScheduler(); }
+
+ private:
+  void StartNewScheduler();
+  void EndScheduler();
+  void ScheduleThreadFunc();
+  void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
+
+  std::vector file_names_;
+  std::vector dims_;
+  std::thread scheduler_;
+  std::vector prefetchers_;
+  framework::Channel* waiting_file_idx_;
+  framework::Channel* available_thread_idx_;
+  framework::Channel>* buffer_;
+  mutable std::vector local_buffer_;
+};
+
+void MultipleReader::ReadNext(std::vector* out) {
+  if (!HasNext()) {
+    PADDLE_THROW("There is no next data!");
+  }
+
+  if (local_buffer_.empty()) {
+    buffer_->Receive(&local_buffer_);
+  }
+  *out = local_buffer_;
+  local_buffer_.clear();
+}
+
+bool MultipleReader::HasNext() const {
+  return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
+}
+
+void MultipleReader::ReInit() {
+  EndScheduler();
+  local_buffer_.clear();
+  StartNewScheduler();
+}
+
+void MultipleReader::StartNewScheduler() {
+  size_t thread_num = prefetchers_.size();
+  waiting_file_idx_ = framework::MakeChannel(file_names_.size());
+  available_thread_idx_ = framework::MakeChannel(thread_num);
+  buffer_ =
+      framework::MakeChannel>(thread_num);
+
+  for (size_t i = 0; i < file_names_.size(); ++i) {
+    waiting_file_idx_->Send(&i);
+  }
+  waiting_file_idx_->Close();
+  for (size_t i = 0; i < thread_num; ++i) {
+    available_thread_idx_->Send(&i);
+  }
+
+  scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
+}
+
+void MultipleReader::EndScheduler() {
+  available_thread_idx_->Close();
+  buffer_->Close();
+  waiting_file_idx_->Close();
+  if (scheduler_.joinable()) {
+    scheduler_.join();
+  }
+  delete buffer_;
+  delete available_thread_idx_;
+  delete waiting_file_idx_;
+}
+
+void MultipleReader::ScheduleThreadFunc() {
+  VLOG(5) << "MultipleReader schedule thread starts.";
+  size_t completed_thread_num = 0;
+  size_t thread_idx;
+  while (available_thread_idx_->Receive(&thread_idx)) {
+    std::thread& prefetcher = prefetchers_[thread_idx];
+    if (prefetcher.joinable()) {
+      prefetcher.join();
+    }
+    size_t file_idx;
+    if (waiting_file_idx_->Receive(&file_idx)) {
+      // Still have files to read. Start a new prefetch thread.
+      std::string file_name = file_names_[file_idx];
+      prefetcher = std::thread([this, file_name, thread_idx] {
+        PrefetchThreadFunc(file_name, thread_idx);
+      });
+    } else {
+      // No more file to read.
+      ++completed_thread_num;
+      if (completed_thread_num == prefetchers_.size()) {
+        buffer_->Close();
+        break;
+      }
+    }
+  }
+  // If users invoke ReInit() when scheduler is running, it will close the
+  // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
+  // to release their resource. So a check is needed before scheduler ends.
+  for (auto& p : prefetchers_) {
+    if (p.joinable()) {
+      p.join();
+    }
+  }
+  VLOG(5) << "MultipleReader schedule thread terminates.";
+}
+
+void MultipleReader::PrefetchThreadFunc(std::string file_name,
+                                        size_t thread_idx) {
+  VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
+  std::unique_ptr reader =
+      CreateReaderByFileName(file_name, dims_);
+  while (reader->HasNext()) {
+    std::vector ins;
+    reader->ReadNext(&ins);
+    if (!buffer_->Send(&ins)) {
+      VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
+                 "thread of file '"
+              << file_name << "' will terminate.";
+      break;
+    }
+  }
+  if (!available_thread_idx_->Send(&thread_idx)) {
+    VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
+               "Fail to send thread_idx.";
+  }
+  VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
+}
+
+class OpenFilesOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    const auto& shape_concat = Attr>("shape_concat");
+    const auto& ranks = Attr>("ranks");
+    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
+    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
+                      int(shape_concat.size()),
+                      "The accumulate of all ranks should be equal to the "
+                      "shape concat's length.");
+    const auto& file_names = Attr>("file_names");
+    PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
+    const size_t thread_num = Attr("thread_num");
+
+    auto* out = scope.FindVar(Output("Out"))
+                    ->template GetMutable();
+    out->Reset(new MultipleReader(
+        file_names, RestoreShapes(shape_concat, ranks), thread_num));
+  }
+};
+
+class OpenFilesOpMaker : public FileReaderMakerBase {
+ public:
+  OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+      : FileReaderMakerBase(op_proto, op_checker) {
+    AddAttr>("file_names", "Files to be read.");
+    AddAttr("thread_num", "The maximal concurrent prefetch thread number.")
+        .GreaterThan(0);
+
+    AddComment(R"DOC(
+      OpenFiles Operator
+
+      An OpenFilesOp creates a MultipleReader, which is able to 
+      read data multi-threaded from multiple files.
+    )DOC");
+  }
+};
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+namespace reader = paddle::operators::reader;
+
+REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
+                              reader::OpenFilesOpMaker);
diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc
index 0ba4f3854431742eb354f8c90eb395f5d7b32b2e..fc8dc747ff0c2286f4516d8350f75d9887361924 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.cc
+++ b/paddle/fluid/operators/reader/reader_op_registry.cc
@@ -36,6 +36,21 @@ std::unordered_map& FileReaderRegistry() {
   return regs;
 }
 
+std::unique_ptr CreateReaderByFileName(
+    const std::string& file_name, const std::vector& dims) {
+  size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
+  PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
+                    "File name illegal! A legal file name should be like: "
+                    "[file_name].[file_format] (e.g., 'data_file.recordio').");
+  std::string filetype = file_name.substr(separator_pos + 1);
+
+  auto itor = FileReaderRegistry().find(filetype);
+  PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
+                 "No file reader registered for '%s' format.", filetype);
+  framework::ReaderBase* reader = (itor->second)(file_name, dims);
+  return std::unique_ptr(reader);
+}
+
 FileReaderMakerBase::FileReaderMakerBase(
     framework::OpProtoAndCheckerMaker::OpProto* op_proto,
     framework::OpAttrChecker* op_checker)
diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h
index 58f9b4ba35546571fd3b1d0c3ce128f18e248f01..929d32ad8b367865e33530f8517343c513ee9878 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.h
+++ b/paddle/fluid/operators/reader/reader_op_registry.h
@@ -21,6 +21,8 @@ namespace paddle {
 namespace operators {
 namespace reader {
 
+static constexpr char kFileFormatSeparator[] = ".";
+
 using FileReaderCreator = std::function&)>;
 
@@ -29,12 +31,15 @@ std::unordered_map& FileReaderRegistry();
 template 
 int RegisterFileReader(const std::string& filetype) {
   FileReaderRegistry()[filetype] = [](
-      const std::string& fn, const std::vector& dim) {
-    return new Reader(fn, dim);
+      const std::string& fn, const std::vector& dims) {
+    return new Reader(fn, dims);
   };
   return 0;
 }
 
+std::unique_ptr CreateReaderByFileName(
+    const std::string& file_name, const std::vector& dims);
+
 extern std::vector RestoreShapes(
     const std::vector& shape_concat, const std::vector& ranks);
 
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 7e001ecc56173db76e8c576e7efd66f41192f292..7c604e14eb245232ed92f53a00b9bde45c2fbaec 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -86,7 +86,8 @@ class CudnnDataType {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
   // The scaling param type is float for HALF and FLOAT tensors
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -101,7 +102,8 @@ template <>
 class CudnnDataType {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -116,7 +118,8 @@ template <>
 class CudnnDataType {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
-  typedef const double ScalingParamType;
+  using ScalingParamType = const double;
+  using BatchNormParamType = double;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
diff --git a/paddle/fluid/recordio/header.cc b/paddle/fluid/recordio/header.cc
index e50de15b7c2b480357f5f6c7daa2b4a676749679..ed09d58f6a3e2dba50bf4407c0463480575b248e 100644
--- a/paddle/fluid/recordio/header.cc
+++ b/paddle/fluid/recordio/header.cc
@@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
 
 bool Header::Parse(std::istream& is) {
   uint32_t magic;
-  size_t read_size =
-      is.readsome(reinterpret_cast(&magic), sizeof(uint32_t));
+  is.read(reinterpret_cast(&magic), sizeof(uint32_t));
+  size_t read_size = is.gcount();
   if (read_size < sizeof(uint32_t)) {
     return false;
   }
diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc
index d842f8fe5a4c9d1a2b564c738d97fffb02f3ccb5..c22281dc97e05173ad76ce76959833b92f11c4ee 100644
--- a/paddle/fluid/recordio/scanner.cc
+++ b/paddle/fluid/recordio/scanner.cc
@@ -28,6 +28,7 @@ Scanner::Scanner(const std::string &filename) {
 }
 
 void Scanner::Reset() {
+  stream_->clear();
   stream_->seekg(0, std::ios::beg);
   ParseNextChunk();
 }
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 9c91f395e7c9d7ca76c1a5cc310bc3bbc06daec9..bc5e291ad811315ddc9d101853d69c7f5ab5082d 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -21,7 +21,8 @@ from ..executor import global_scope
 
 __all__ = [
     'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
-    'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
+    'open_files', 'read_file', 'create_shuffle_reader',
+    'create_double_buffer_reader', 'create_multi_pass_reader'
 ]
 
 
@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
                              startup_var)
 
 
+def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
+    dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
+    shape_concat = []
+    ranks = []
+
+    for shape in shapes:
+        shape_concat.extend(shape)
+        ranks.append(len(shape))
+
+    var_name = unique_name('multiple_reader')
+
+    startup_blk = default_startup_program().current_block()
+    startup_var = startup_blk.create_var(name=var_name)
+    startup_blk.append_op(
+        type='open_files',
+        outputs={'Out': [startup_var]},
+        attrs={
+            'shape_concat': shape_concat,
+            'lod_levels': lod_levels,
+            'ranks': ranks,
+            'file_names': filenames,
+            'thread_num': thread_num
+        })
+
+    startup_var.desc.set_dtypes(dtypes)
+    startup_var.persistable = True
+    return _copy_reader_var_(default_main_program().current_block(),
+                             startup_var)
+
+
 def __create_decorated_reader__(op_type, reader, attrs):
     var_name = unique_name(op_type)
     startup_blk = default_startup_program().current_block()
@@ -314,6 +345,11 @@ def create_double_buffer_reader(reader, place=None):
                                        attrs)
 
 
+def create_multi_pass_reader(reader, pass_num):
+    return __create_decorated_reader__('create_multi_pass_reader', reader,
+                                       {'pass_num': int(pass_num)})
+
+
 def read_file(file_obj):
     helper = LayerHelper('read_file')
     out = [
diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore
index 6b3fc2a83c649c28d21c9a8a0b35c2f2fa04f269..ad02bdecf436bba925e2e3b7efb20c878df70dfd 100644
--- a/python/paddle/fluid/tests/unittests/.gitignore
+++ b/python/paddle/fluid/tests/unittests/.gitignore
@@ -1 +1,4 @@
 mnist.recordio
+mnist_0.recordio
+mnist_1.recordio
+mnist_2.recordio
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
index 80e6fa6df3c21aa19feb571916f11c41ccd6bb10..10aa63e18a6eeaa44e5b12f7532998dca2bc5e9f 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
     return backward_op
 
 
+def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
+    x_shape = x.shape
+    if len(x_shape) == 2:
+        if data_format == "NCHW":
+            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
+        else:
+            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+
+    if data_format == "NCHW":
+        n, c, h, w = x.shape
+        mean_tile = np.reshape(mean, (1, c, 1, 1))
+        mean_tile = np.tile(mean_tile, (n, 1, h, w))
+        var_tile = np.reshape(var, (1, c, 1, 1))
+        var_tile = np.tile(var_tile, (n, 1, h, w))
+        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
+        scale_tile = np.reshape(scale, (1, c, 1, 1))
+        scale_tile = np.tile(scale_tile, (n, 1, h, w))
+        offset_tile = np.reshape(offset, (1, c, 1, 1))
+        offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
+        y = normalized * scale_tile + offset_tile
+    elif data_format == "NHWC":
+        normalized = (x - mean) / np.sqrt(var + epsilon)
+        y = normalized * scale + offset
+    else:
+        raise ValueError("Unknown data order.")
+
+    if len(x_shape) == 2:
+        y = np.reshape(y, x_shape)
+    return y
+
+
 def _reference_training(x, scale, offset, epsilon, data_format):
     x_shape = x.shape
     if len(x_shape) == 2:
@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
         __set_tensor__(output, data)
 
 
-class TestBatchNormOp(OpTest):
+class TestBatchNormOpInference(OpTest):
+    def setUp(self):
+        self.dtype = np.float32
+
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def test_python(self):
+    def check_with_place(self, place, data_layout, dtype, shape):
+        epsilon = 0.00001
+        if len(shape) == 2:
+            x_shape = shape
+            c = x_shape[1]
+        else:
+            n, h, w, c = shape[0], shape[1], shape[2], shape[3]
+            if data_layout == "NHWC":
+                x_shape = [n, h, w, c]
+            elif data_layout == "NCHW":
+                x_shape = [n, c, h, w]
+            else:
+                raise ValueError("Unknown data layout.")
+        scale_shape = [c]
+
+        x_val = np.random.random_sample(x_shape).astype(dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
+
+        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+                                   epsilon, data_layout).astype(dtype)
+
+        scope = core.Scope()
+
+        # create input
+        x_tensor = create_or_get_tensor(scope, "x_val",
+                                        OpTest.np_dtype_to_fluid_dtype(x_val),
+                                        place)
+        scale_tensor = create_or_get_tensor(
+            scope, "scale_val",
+            OpTest.np_dtype_to_fluid_dtype(scale_val), place)
+        bias_tensor = create_or_get_tensor(
+            scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
+        mean_tensor = create_or_get_tensor(scope, "mean",
+                                           OpTest.np_dtype_to_fluid_dtype(mean),
+                                           place)
+        variance_tensor = create_or_get_tensor(
+            scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)
+
+        # create output
+        y_tensor = create_or_get_tensor(scope, "y_out", None, place)
+        saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
+                                                 place)
+        saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
+                                                     None, place)
+        mean_out_tensor = mean_tensor
+        variance_out_tensor = variance_tensor
+
+        batch_norm_op = Operator(
+            "batch_norm",
+            # inputs
+            X="x_val",
+            Scale="scale_val",
+            Bias="bias_val",
+            Mean="mean",
+            Variance="variance",
+            # outputs
+            Y="y_out",
+            MeanOut="mean",
+            VarianceOut="variance",
+            SavedMean="saved_mean",
+            SavedVariance="saved_variance",
+            # attrs
+            is_test=True,
+            data_layout=data_layout,
+            epsilon=epsilon)
+
+        batch_norm_op.run(scope, place)
+
+        # check inference result
+        self.__assert_close(
+            y_tensor,
+            y_out,
+            "inference output are different at " + str(place) + ", " +
+            data_layout + ", " + str(np.dtype(dtype)) +
+            str(np.array(y_tensor)) + str(y_out),
+            atol=1e-3)
+
+    def test_check_output(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+            places.append(core.CUDAPlace(0))
+
+        for place in places:
+            for data_format in ["NCHW", "NHWC"]:
+                self.check_with_place(place, data_format, self.dtype,
+                                      [2, 3, 4, 5])
+                self.check_with_place(place, data_format, self.dtype, [2, 3])
+
+
+class TestFP16BatchNormOpInference(TestBatchNormOpInference):
+    def setUp(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        places = []
+        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                places.append(place)
+
+        for place in places:
+            for data_format in ["NCHW", "NHWC"]:
+                self.check_with_place(place, data_format, self.dtype,
+                                      [2, 3, 4, 5])
+                self.check_with_place(place, data_format, self.dtype, [2, 3])
+
+
+class TestBatchNormOpTraining(OpTest):
+    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
+        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
+
+    def test_python_testing(self):
+        data_format = "NHWC"
+        epsilon = 0.00001
+
+        n, h, w, c = 2, 3, 4, 5
+        x_shape = [n, h, w, c]
+        scale_shape = [c]
+
+        x_val = np.random.random_sample(x_shape).astype(np.float32)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
+
+        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+                                   epsilon, "NHWC")
+
+        # running N, C, H, W case
+        # should produce the same results
+        x_shape2 = [n, c, h, w]
+        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
+        y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
+                                    epsilon, "NCHW")
+
+        # transfer (N, C, H, W) back to (N, H, W, C)
+        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
+        self.__assert_close(y_out, y_out2_trans, "inference output")
+        print 'python: NHWC, NCHW, inference checking passed'
+
+    def test_python_training(self):
         data_format = "NHWC"
         epsilon = 0.00001
         momentum = 0.9
@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):
 
         # transfer (N, C, H, W) back to (N, H, W, C)
         y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
-        self.__assert_close(y_out, y_out2_trans, "batch variance")
+        self.__assert_close(y_out, y_out2_trans, "batch output")
         print 'python: NHWC, NCHW, forward checking passed'
 
         # test backward now
diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..8add353303e3626bbce68199a100306d4858766a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
@@ -0,0 +1,65 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+import paddle.v2.dataset.mnist as mnist
+
+
+class TestMultipleReader(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 64
+        self.pass_num = 3
+        # Convert mnist to recordio file
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
+            feeder = fluid.DataFeeder(
+                feed_list=[
+                    fluid.layers.data(
+                        name='image', shape=[784]),
+                    fluid.layers.data(
+                        name='label', shape=[1], dtype='int64'),
+                ],
+                place=fluid.CPUPlace())
+            self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
+                './mnist.recordio', data_file, feeder)
+
+    def test_main(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data_file = fluid.layers.open_recordio_file(
+                filename='./mnist.recordio',
+                shapes=[(-1, 784), (-1, 1)],
+                lod_levels=[0, 0],
+                dtypes=['float32', 'int64'])
+            data_file = fluid.layers.create_multi_pass_reader(
+                reader=data_file, pass_num=self.pass_num)
+            img, label = fluid.layers.read_file(data_file)
+
+            if fluid.core.is_compiled_with_cuda():
+                place = fluid.CUDAPlace(0)
+            else:
+                place = fluid.CPUPlace()
+
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+
+            batch_count = 0
+            while not data_file.eof():
+                img_val, = exe.run(fetch_list=[img])
+                batch_count += 1
+                self.assertLessEqual(img_val.shape[0], self.batch_size)
+            data_file.reset()
+            self.assertEqual(batch_count, self.num_batch * self.pass_num)
diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f8acf81efaba8fc0f3df4cfe3a42dc4e477df2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py
@@ -0,0 +1,74 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+import paddle.v2.dataset.mnist as mnist
+from shutil import copyfile
+
+
+class TestMultipleReader(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 64
+        # Convert mnist to recordio file
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
+            feeder = fluid.DataFeeder(
+                feed_list=[  # order is image and label
+                    fluid.layers.data(
+                        name='image', shape=[784]),
+                    fluid.layers.data(
+                        name='label', shape=[1], dtype='int64'),
+                ],
+                place=fluid.CPUPlace())
+            self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
+                './mnist_0.recordio', reader, feeder)
+        copyfile('./mnist_0.recordio', './mnist_1.recordio')
+        copyfile('./mnist_0.recordio', './mnist_2.recordio')
+
+    def main(self, thread_num):
+        file_list = [
+            './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
+        ]
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data_files = fluid.layers.open_files(
+                filenames=file_list,
+                thread_num=thread_num,
+                shapes=[(-1, 784), (-1, 1)],
+                lod_levels=[0, 0],
+                dtypes=['float32', 'int64'])
+            img, label = fluid.layers.read_file(data_files)
+
+            if fluid.core.is_compiled_with_cuda():
+                place = fluid.CUDAPlace(0)
+            else:
+                place = fluid.CPUPlace()
+
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+
+            batch_count = 0
+            while not data_files.eof():
+                img_val, = exe.run(fetch_list=[img])
+                batch_count += 1
+                self.assertLessEqual(img_val.shape[0], self.batch_size)
+            data_files.reset()
+            self.assertEqual(batch_count, self.num_batch * 3)
+
+    def test_main(self):
+        self.main(thread_num=3)  # thread number equals to file number
+        self.main(thread_num=10)  # thread number is larger than file number
+        self.main(thread_num=2)  # thread number is less than file number