提交 83a06f69 编写于 作者: L Liu Yiqun

Merge branch 'develop' into core_inference_multi_thread

...@@ -5,7 +5,7 @@ This document describes the RNN (Recurrent Neural Network) operator and how it i ...@@ -5,7 +5,7 @@ This document describes the RNN (Recurrent Neural Network) operator and how it i
## RNN Algorithm Implementation ## RNN Algorithm Implementation
<p align="center"> <p align="center">
<img src="./images/rnn.jpg"/> <img src="./rnn.jpg"/>
</p> </p>
The above diagram shows an RNN unrolled into a full network. The above diagram shows an RNN unrolled into a full network.
...@@ -22,7 +22,7 @@ There are several important concepts here: ...@@ -22,7 +22,7 @@ There are several important concepts here:
There could be local variables defined in each step-net. PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step. There could be local variables defined in each step-net. PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step.
<p align="center"> <p align="center">
<img src="./images/rnn.png"/><br/> <img src="./rnn.png"/><br/>
Figure 2 illustrates the RNN's data flow Figure 2 illustrates the RNN's data flow
</p> </p>
...@@ -49,7 +49,7 @@ or copy the memory value of the previous step to the current ex-memory variable. ...@@ -49,7 +49,7 @@ or copy the memory value of the previous step to the current ex-memory variable.
### Usage in Python ### Usage in Python
For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md). For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/block.md).
We can define an RNN's step-net using a Block: We can define an RNN's step-net using a Block:
...@@ -93,7 +93,7 @@ For example, we could have a 2-level RNN, where the top level corresponds to par ...@@ -93,7 +93,7 @@ For example, we could have a 2-level RNN, where the top level corresponds to par
The following figure illustrates feeding in text into the lower level, one sentence at a step, and the feeding in step outputs to the top level. The final top level output is about the whole text. The following figure illustrates feeding in text into the lower level, one sentence at a step, and the feeding in step outputs to the top level. The final top level output is about the whole text.
<p align="center"> <p align="center">
<img src="./images/2_level_rnn.png"/> <img src="./2_level_rnn.png"/>
</p> </p>
```python ```python
...@@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st ...@@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st
<p align="center"> <p align="center">
<img src="images/rnn_2level_data.png"/> <img src="./rnn_2level_data.png"/>
</p> </p>
...@@ -185,7 +185,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -185,7 +185,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets, std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name, const std::string& feed_holder_name,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name, bool create_vars) {
platform::RecordBlock b(kProgramId); platform::RecordBlock b(kProgramId);
bool has_feed_ops = bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name); has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
...@@ -255,7 +255,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -255,7 +255,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
} }
Run(*copy_program, scope, 0, true, true); Run(*copy_program, scope, 0, create_vars, create_vars);
// obtain the data of fetch_targets from fetch_holder // obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block->AllOps()) { for (auto* op : global_block->AllOps()) {
......
...@@ -54,7 +54,8 @@ class Executor { ...@@ -54,7 +54,8 @@ class Executor {
std::map<std::string, const LoDTensor*>& feed_targets, std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name = "feed", const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id); const ProgramDesc& program, int block_id);
......
...@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SavedVariance", {C}); ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y"); ctx->ShareLoD("X", "Y");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// For float or float16 input tensor, the type of the scale, bias, mean,
// and var tensors should both be float.
auto bn_param_type = framework::proto::VarType::FP32;
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
"Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
"Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <cfloat> #include <cfloat>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,6 +27,8 @@ using Tensor = framework::Tensor; ...@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout; using DataLayout = framework::DataLayout;
template <typename T> template <typename T>
using CudnnDataType = platform::CudnnDataType<T>; using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
int *N, int *C, int *H, int *W, int *D) { int *N, int *C, int *H, int *W, int *D) {
...@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
// Note: PERSISTENT not implemented for inference
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_)); bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias"); const auto *bias = ctx.Input<Tensor>("Bias");
...@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// alloc memory // alloc memory
y->mutable_data<T>(ctx.GetPlace()); y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace()); mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace()); variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace()); saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace()); saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> functor; math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor(dev_ctx, saved_mean, 0); functor;
functor(dev_ctx, saved_variance, 0); functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
...@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(), CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(), CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()), data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<T>(), bias->template data<T>(), bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
est_mean->template data<T>(), est_var->template data<T>(), epsilon)); bias->template data<BatchNormParamType<T>>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(), epsilon));
} else { } else {
// Run training mode. // Run training mode.
// obtain running mean and running inv var, and see if we need to // obtain running mean and running inv var, and see if we need to
...@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_, data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_, y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(), bias->template data<T>(), this_factor, scale->template data<BatchNormParamType<T>>(),
mean_out->template mutable_data<T>(ctx.GetPlace()), bias->template data<BatchNormParamType<T>>(), this_factor,
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon, mean_out->template mutable_data<BatchNormParamType<T>>(
saved_mean->template mutable_data<T>(ctx.GetPlace()), ctx.GetPlace()),
saved_variance->template mutable_data<T>(ctx.GetPlace()))); variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
} }
// clean when exit. // clean when exit.
...@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
batch_norm, batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>); ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
batch_norm_grad, batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>( ...@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
cblas_daxpy(n, alpha, x, 1, y, 1); cblas_daxpy(n, alpha, x, 1, y, 1);
} }
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, float>; template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>; template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>; template struct SetConstant<platform::CPUDeviceContext, int>;
......
...@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>( ...@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
&alpha, x, 1, y, 1)); &alpha, x, 1, y, 1));
} }
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>; template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>; template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>; template struct SetConstant<platform::CUDADeviceContext, int>;
......
...@@ -15,10 +15,12 @@ function(reader_library TARGET_NAME) ...@@ -15,10 +15,12 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE) PARENT_SCOPE)
endfunction() endfunction()
reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
# Export local libraries to parent # Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.payloads_.empty()) { if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_); buffer_->Receive(&local_buffer_);
} }
*out = local_buffer_.payloads_; *out = local_buffer_.payloads_;
local_buffer_.payloads_.clear(); local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) { if (local_buffer_.ctx_) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(ReaderBase* reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out);
}
bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
++pass_count_;
if (pass_count_ >= pass_num_) {
return false;
} else {
reader_->ReInit();
return true;
}
}
}
void ReInit() override {
pass_count_ = 0;
reader_->ReInit();
}
private:
int pass_num_;
mutable int pass_count_;
};
class CreateMultiPassReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto& out = detail::Ref(scope.FindVar(Output("Out")));
int pass_num = Attr<int>("pass_num");
out.GetMutable<framework::ReaderHolder>()->Reset(
new MultiPassReader(underlying_reader.Get(), pass_num));
}
};
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
AddComment(R"DOC(
CreateMultiPassReader Operator
This operator creates a multi-pass reader. A multi-pass reader
is used to yield data for several pass training continuously.
It takes the the number of pass to run as one of its attributes
('pass_num'), and maintains a pass counter to record how many
passes it has completed. When the underlying reader reach the EOF,
the multi-pass reader checks whether it has completed training
of the given number of pass. If not, the underlying reader will
be re-initialized and starts a new pass automatically.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
ops::CreateMultiPassReaderOp,
ops::CreateMultiPassReaderOpMaker);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class MultipleReader : public framework::ReaderBase {
public:
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
~MultipleReader() { EndScheduler(); }
private:
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_;
local_buffer_.clear();
}
bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
}
void MultipleReader::ReInit() {
EndScheduler();
local_buffer_.clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
if (!available_thread_idx_->Send(&thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}
class OpenFilesOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
}
};
class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
reader::OpenFilesOpMaker);
...@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { ...@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs; return regs;
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_name].[file_format] (e.g., 'data_file.recordio').");
std::string filetype = file_name.substr(separator_pos + 1);
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}
FileReaderMakerBase::FileReaderMakerBase( FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
......
...@@ -21,6 +21,8 @@ namespace paddle { ...@@ -21,6 +21,8 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*( using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>; const std::string&, const std::vector<framework::DDim>&)>;
...@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); ...@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader> template <typename Reader>
int RegisterFileReader(const std::string& filetype) { int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = []( FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) { const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dim); return new Reader(fn, dims);
}; };
return 0; return 0;
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
...@@ -86,7 +86,8 @@ class CudnnDataType<float16> { ...@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
public: public:
static const cudnnDataType_t type = CUDNN_DATA_HALF; static const cudnnDataType_t type = CUDNN_DATA_HALF;
// The scaling param type is float for HALF and FLOAT tensors // The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType; using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() { static ScalingParamType* kOne() {
static ScalingParamType v = 1.0; static ScalingParamType v = 1.0;
return &v; return &v;
...@@ -101,7 +102,8 @@ template <> ...@@ -101,7 +102,8 @@ template <>
class CudnnDataType<float> { class CudnnDataType<float> {
public: public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT; static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType; using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() { static ScalingParamType* kOne() {
static ScalingParamType v = 1.0; static ScalingParamType v = 1.0;
return &v; return &v;
...@@ -116,7 +118,8 @@ template <> ...@@ -116,7 +118,8 @@ template <>
class CudnnDataType<double> { class CudnnDataType<double> {
public: public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
typedef const double ScalingParamType; using ScalingParamType = const double;
using BatchNormParamType = double;
static ScalingParamType* kOne() { static ScalingParamType* kOne() {
static ScalingParamType v = 1.0; static ScalingParamType v = 1.0;
return &v; return &v;
......
...@@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) ...@@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
bool Header::Parse(std::istream& is) { bool Header::Parse(std::istream& is) {
uint32_t magic; uint32_t magic;
size_t read_size = is.read(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t)); size_t read_size = is.gcount();
if (read_size < sizeof(uint32_t)) { if (read_size < sizeof(uint32_t)) {
return false; return false;
} }
......
...@@ -28,6 +28,7 @@ Scanner::Scanner(const std::string &filename) { ...@@ -28,6 +28,7 @@ Scanner::Scanner(const std::string &filename) {
} }
void Scanner::Reset() { void Scanner::Reset() {
stream_->clear();
stream_->seekg(0, std::ios::beg); stream_->seekg(0, std::ios::beg);
ParseNextChunk(); ParseNextChunk();
} }
......
...@@ -21,7 +21,8 @@ from ..executor import global_scope ...@@ -21,7 +21,8 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' 'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader'
] ]
...@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var) startup_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('multiple_reader')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def __create_decorated_reader__(op_type, reader, attrs): def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
...@@ -314,6 +345,11 @@ def create_double_buffer_reader(reader, place=None): ...@@ -314,6 +345,11 @@ def create_double_buffer_reader(reader, place=None):
attrs) attrs)
def create_multi_pass_reader(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})
def read_file(file_obj): def read_file(file_obj):
helper = LayerHelper('read_file') helper = LayerHelper('read_file')
out = [ out = [
......
mnist.recordio mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
...@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set): ...@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op return backward_op
def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 2:
if data_format == "NCHW":
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
else:
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
if data_format == "NCHW":
n, c, h, w = x.shape
mean_tile = np.reshape(mean, (1, c, 1, 1))
mean_tile = np.tile(mean_tile, (n, 1, h, w))
var_tile = np.reshape(var, (1, c, 1, 1))
var_tile = np.tile(var_tile, (n, 1, h, w))
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
elif data_format == "NHWC":
normalized = (x - mean) / np.sqrt(var + epsilon)
y = normalized * scale + offset
else:
raise ValueError("Unknown data order.")
if len(x_shape) == 2:
y = np.reshape(y, x_shape)
return y
def _reference_training(x, scale, offset, epsilon, data_format): def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape x_shape = x.shape
if len(x_shape) == 2: if len(x_shape) == 2:
...@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None): ...@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
__set_tensor__(output, data) __set_tensor__(output, data)
class TestBatchNormOp(OpTest): class TestBatchNormOpInference(OpTest):
def setUp(self):
self.dtype = np.float32
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_python(self): def check_with_place(self, place, data_layout, dtype, shape):
epsilon = 0.00001
if len(shape) == 2:
x_shape = shape
c = x_shape[1]
else:
n, h, w, c = shape[0], shape[1], shape[2], shape[3]
if data_layout == "NHWC":
x_shape = [n, h, w, c]
elif data_layout == "NCHW":
x_shape = [n, c, h, w]
else:
raise ValueError("Unknown data layout.")
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(dtype)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, data_layout).astype(dtype)
scope = core.Scope()
# create input
x_tensor = create_or_get_tensor(scope, "x_val",
OpTest.np_dtype_to_fluid_dtype(x_val),
place)
scale_tensor = create_or_get_tensor(
scope, "scale_val",
OpTest.np_dtype_to_fluid_dtype(scale_val), place)
bias_tensor = create_or_get_tensor(
scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
mean_tensor = create_or_get_tensor(scope, "mean",
OpTest.np_dtype_to_fluid_dtype(mean),
place)
variance_tensor = create_or_get_tensor(
scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)
# create output
y_tensor = create_or_get_tensor(scope, "y_out", None, place)
saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
place)
saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
None, place)
mean_out_tensor = mean_tensor
variance_out_tensor = variance_tensor
batch_norm_op = Operator(
"batch_norm",
# inputs
X="x_val",
Scale="scale_val",
Bias="bias_val",
Mean="mean",
Variance="variance",
# outputs
Y="y_out",
MeanOut="mean",
VarianceOut="variance",
SavedMean="saved_mean",
SavedVariance="saved_variance",
# attrs
is_test=True,
data_layout=data_layout,
epsilon=epsilon)
batch_norm_op.run(scope, place)
# check inference result
self.__assert_close(
y_tensor,
y_out,
"inference output are different at " + str(place) + ", " +
data_layout + ", " + str(np.dtype(dtype)) +
str(np.array(y_tensor)) + str(y_out),
atol=1e-3)
def test_check_output(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(core.CUDAPlace(0))
for place in places:
for data_format in ["NCHW", "NHWC"]:
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestFP16BatchNormOpInference(TestBatchNormOpInference):
def setUp(self):
self.dtype = np.float16
def test_check_output(self):
places = []
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
places.append(place)
for place in places:
for data_format in ["NCHW", "NHWC"]:
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestBatchNormOpTraining(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_python_testing(self):
data_format = "NHWC"
epsilon = 0.00001
n, h, w, c = 2, 3, 4, 5
x_shape = [n, h, w, c]
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, "NHWC")
# running N, C, H, W case
# should produce the same results
x_shape2 = [n, c, h, w]
x_val2 = np.transpose(x_val, (0, 3, 1, 2))
y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
epsilon, "NCHW")
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "inference output")
print 'python: NHWC, NCHW, inference checking passed'
def test_python_training(self):
data_format = "NHWC" data_format = "NHWC"
epsilon = 0.00001 epsilon = 0.00001
momentum = 0.9 momentum = 0.9
...@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest): ...@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):
# transfer (N, C, H, W) back to (N, H, W, C) # transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "batch variance") self.__assert_close(y_out, y_out2_trans, "batch output")
print 'python: NHWC, NCHW, forward checking passed' print 'python: NHWC, NCHW, forward checking passed'
# test backward now # test backward now
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
self.pass_num = 3
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', data_file, feeder)
def test_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
data_file = fluid.layers.create_multi_pass_reader(
reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_file.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
self.assertEqual(batch_count, self.num_batch * self.pass_num)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
from shutil import copyfile
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files(
filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册