提交 809530f4 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_MultiEpochReader

# select_op Design
## Introduction
In golang, the [**select**](https://golang.org/ref/spec#Select_statements)
statement lets a goroutine wait on multiple communication operations at the
same time. The **select** blocks until one of its cases can run, then
executes the case. If multiple cases are ready to run, then one case is
choosen at random to be executed.
With the introduction of CSP for Paddle, we mimic this behavior by
creating a ***select_op***.
## How to use it
The **select_op** is available as a c++ operator. However most users
will prefer to use the much simplier Python API.
- **fluid.Select()**: Creates a select operator and adds it to the current
block within the main program. Also creates a sub block and adds it to the
main program. This sub block is used to hold all variables and operators
used by the case statements.
Within the select block, users can add cases by
calling **select.case** or **select.default** method.
- **fluid.Select.case(channel_action, channel, result_variable)**: Represents
a fluid channel send/recv case. This method creates a SelectCase block
guard and adds it to the Select block. The arguments into this method tells
the select which channel operation to listen to.
- **fluid.Select.default()**: Represents the fluid default case. This default
case is executed if none of the channel send/recv cases are available to
execute.
**Example:**
```
ch1 = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
quit_ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
x = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
y = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=1)
while_cond = fill_constant(shape=[1], dtype=core.VarDesc.VarType.BOOL, value=True)
while_op = While(cond=while_cond)
with while_op.block():
with fluid.Select() as select:
with select.case(fluid.channel_send, channel, x):
# Send x, then perform Fibonacci calculation on x and y
x_tmp = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
assign(input=x, output=x_tmp)
assign(input=y, output=x)
assign(elementwise_add(x=x_tmp, y=y), output=y)
with select.case(fluid.channel_recv, quit_channel, result2):
# Exit out of While loop
while_false = fill_constant(shape=[1], dtype=core.VarDesc.VarType.BOOL, value=False)
helper = layer_helper.LayerHelper('assign')
helper.append_op(
type='assign',
inputs={'X': [while_false]},
outputs={'Out': [while_cond]})
```
## How it Works
### Program Description
```
blocks {
idx: 0
...
// Create "case_to_execute" variable
ops {
outputs {
parameter: "Out"
arguments: "fill_constant_110.tmp_0"
}
type: "fill_constant"
attrs {
name: "force_cpu"
type: BOOLEAN
b: false
}
attrs {
name: "value"
type: FLOAT
f: -1.0
}
attrs {
name: "shape"
type: INTS
ints: 1
}
attrs {
name: "dtype"
type: INT
i: 2
}
}
// Create "select" operator.
// inputs:
// X: All input variables used by operators within the select block
// case_to_execute: Variable filled in by select_op when it determines
// which case to execute.
//
// outputs:
// Out: All output variables referenced by operators within select block.
//
// attrs:
// sub_block: The block id containing the select "cases"
// cases: Serialized list of all cases in the select op.
// Each case is serialized as: '<index>,<type>,<channel>,<value>'
// where type is 0 for default, 1 for send, and 2 for receive.
// No channel and values are needed for default cases.
ops {
inputs {
parameter: "X"
arguments: "fill_constant_103.tmp_0"
arguments: "fill_constant_104.tmp_0"
}
inputs {
parameter: "case_to_execute"
arguments: "fill_constant_110.tmp_0"
}
outputs {
parameter: "Out"
arguments: "fill_constant_110.tmp_0"
}
type: "select"
attrs {
name: "sub_block"
type: BLOCK
block_idx: 1
}
attrs {
name: "cases"
type: STRINGS
strings: "0,1,channel_101,fill_constant_109.tmp_0"
strings: "1,2,channel_102,fill_constant_108.tmp_0"
}
}
...
}
```
The python select API will add the **select_op** to the current block. In addition, it will
iterate through all it's case statements and add any input variables required by case statements
into **X**. It will also create a temp variable called **case_to_execute**. This variable is
filled in by the select_op after it has completed processing the case statements.
If there are no available cases to execute (ie: all cases are blocked on channel operations, and
there is no default statement), then the select_op will block the current thread. The thread will
unblock once there is a channel operation affecting one of the case statements, at which point, the
**select_op** will set the **case_to_execute** variable to the index of the case to execute.
Finally the select_op will call executor.run on the **sub_block**.
```
blocks {
idx: 1
parent_idx: 0
...
// Fill a tensor with the case index (ie: 0,1,2,3,ect.)
ops {
outputs {
parameter: "Out"
arguments: "fill_constant_111.tmp_0"
}
type: "fill_constant"
attrs {
name: "force_cpu"
type: BOOLEAN
b: false
}
attrs {
name: "value"
type: FLOAT
f: 0.0
}
attrs {
name: "shape"
type: INTS
ints: 1
}
attrs {
name: "dtype"
type: INT
i: 2
}
}
// Create an "equal" operator to compare the case index with the "case_to_execute"
// tensor (which was filled in by the select op).
ops {
inputs {
parameter: "X"
arguments: "fill_constant_111.tmp_0" // case 0
}
inputs {
parameter: "Y"
arguments: "fill_constant_110.tmp_0" // case_to_execute
}
outputs {
parameter: "Out"
arguments: "equal_0.tmp_0"
}
type: "equal"
attrs {
name: "axis"
type: INT
i: -1
}
}
// Use the output of the "equal" operator as a condition for the "conditional_block".
// If the condition evaluates to true, then execute the "sub_block" (which represents
// the select case's body)
ops {
inputs {
parameter: "Params"
}
inputs {
parameter: "X"
arguments: "equal_0.tmp_0"
}
outputs {
parameter: "Out"
}
outputs {
parameter: "Scope"
arguments: "_generated_var_0"
}
type: "conditional_block"
attrs {
name: "is_scalar_condition"
type: BOOLEAN
b: true
}
attrs {
name: "sub_block"
type: BLOCK
block_idx: 4
}
}
...
// Repeat the above operators for each case statements inside the select body
}
```
Cases are represented by a **conditional_block operator**, whose's condition is set as the output of
equal(**case_to_execute**, **case_index**). Since each case index is unique in this sub-block,
only one case will be executed.
### select_op flow
<p align="center">
<img src="./images/select_op_workflow.png"/><br/>
</p>
The select algorithm is inspired by golang's select routine. Please refer to
http://www.tapirgames.com/blog/golang-concurrent-select-implementation for more information.
## Backward Pass
TODO
GET STARTED
============
If you want to quickly know how to use PaddlePaddle, please refer to the following guide:
.. toctree::
:maxdepth: 1
quickstart_en.rst
While using PaddlePaddle to build applications, please understand some basic concepts.
Here is an example of linear regression. It introduces workflow of PaddlePaddle, including data format, model configuration and training, etc.
.. toctree::
:maxdepth: 1
concepts/use_concepts_en.rst
......@@ -2,6 +2,9 @@ Distributed Training
====================
The effectiveness of the deep learning model is often directly related to the scale of the data: it can generally achieve better results after increasing the size of the dataset on the same model. However, it can not fit in one single computer when the amount of data increases to a certain extent. At this point, using multiple computers for distributed training is a natural solution. In distributed training, the training data is divided into multiple copies (sharding), and multiple machines participating in the training read their own data for training and collaboratively update the parameters of the overall model.
Distributed training generally has framwork as shown below:
.. image:: src/ps_en.png
:width: 500
......
......@@ -14,12 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <set>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -40,14 +36,13 @@ namespace {
int kProgramId = -1;
} // namespace
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
}
Executor::Executor(const platform::Place& place) : place_(place) {}
......@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id);
auto* ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx, scope, create_local_scope, create_vars);
delete ctx;
auto ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}
// Check whether the block already has feed operators and feed_holder.
......@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
int block_id) {
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id);
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
return ctx;
return std::unique_ptr<ExecutorPrepareContext>(ctx);
}
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
......
......@@ -22,7 +22,16 @@ limitations under the License. */
namespace paddle {
namespace framework {
struct ExecutorPrepareContext;
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext();
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
class Executor {
public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed
......@@ -47,8 +56,8 @@ class Executor {
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
int block_id);
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
......
......@@ -14,19 +14,20 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int64_t>);
elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
......@@ -89,6 +90,10 @@ class ListenAndServOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
int num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
framework::Executor executor(dev_place);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
......@@ -132,12 +137,36 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->ShutDown();
break;
}
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
// put optimize blocks in the thread pool to start run, the last block
// should be global ops.
// NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
// and this will still work.
std::vector<std::future<void>> fs;
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(framework::Async([&executor, &program, &recv_scope,
blkid]() {
int run_block = blkid; // thread local
try {
executor.Run(*program, &recv_scope, run_block,
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
}
for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait();
// Run global block at final step, or block1 if there are only 2 blocks
if (num_blocks >= 2) {
try {
executor.Run(*program, &recv_scope, num_blocks - 1,
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
......
......@@ -15,6 +15,7 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE)
endfunction()
reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
......
......@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class MultipleReader : public framework::ReaderBase {
public:
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
~MultipleReader() { EndScheduler(); }
private:
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_;
local_buffer_.clear();
}
bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
}
void MultipleReader::ReInit() {
EndScheduler();
local_buffer_.clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
if (!available_thread_idx_->Send(&thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}
class OpenFilesOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
}
};
class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
reader::OpenFilesOpMaker);
......@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs;
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_name].[file_format] (e.g., 'data_file.recordio').");
std::string filetype = file_name.substr(separator_pos + 1);
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
......
......@@ -21,6 +21,8 @@ namespace paddle {
namespace operators {
namespace reader {
static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
......@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dims);
};
return 0;
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
......@@ -600,7 +600,7 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) {
inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -616,7 +616,7 @@ HOST inline float16 operator+(const float16& a, const float16& b) {
return res;
}
HOST inline float16 operator-(const float16& a, const float16& b) {
inline float16 operator-(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -632,7 +632,7 @@ HOST inline float16 operator-(const float16& a, const float16& b) {
return res;
}
HOST inline float16 operator*(const float16& a, const float16& b) {
inline float16 operator*(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -648,7 +648,7 @@ HOST inline float16 operator*(const float16& a, const float16& b) {
return res;
}
HOST inline float16 operator/(const float16& a, const float16& b) {
inline float16 operator/(const float16& a, const float16& b) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -664,7 +664,7 @@ HOST inline float16 operator/(const float16& a, const float16& b) {
return res;
}
HOST inline float16 operator-(const float16& a) {
inline float16 operator-(const float16& a) {
float16 res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -679,27 +679,27 @@ HOST inline float16 operator-(const float16& a) {
return res;
}
HOST inline float16& operator+=(float16& a, const float16& b) {
inline float16& operator+=(float16& a, const float16& b) {
a = a + b;
return a;
}
HOST inline float16& operator-=(float16& a, const float16& b) {
inline float16& operator-=(float16& a, const float16& b) {
a = a - b;
return a;
}
HOST inline float16& operator*=(float16& a, const float16& b) {
inline float16& operator*=(float16& a, const float16& b) {
a = a * b;
return a;
}
HOST inline float16& operator/=(float16& a, const float16& b) {
inline float16& operator/=(float16& a, const float16& b) {
a = a / b;
return a;
}
HOST inline bool operator==(const float16& a, const float16& b) {
inline bool operator==(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -715,11 +715,9 @@ HOST inline bool operator==(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
HOST inline bool operator!=(const float16& a, const float16& b) {
return !(a == b);
}
inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }
HOST inline bool operator<(const float16& a, const float16& b) {
inline bool operator<(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v1.h}[0], [%[a_ptr]]\n"
......@@ -735,7 +733,7 @@ HOST inline bool operator<(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
HOST inline bool operator<=(const float16& a, const float16& b) {
inline bool operator<=(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v1.h}[0], [%[a_ptr]]\n"
......@@ -751,7 +749,7 @@ HOST inline bool operator<=(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
HOST inline bool operator>(const float16& a, const float16& b) {
inline bool operator>(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -767,7 +765,7 @@ HOST inline bool operator>(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
HOST inline bool operator>=(const float16& a, const float16& b) {
inline bool operator>=(const float16& a, const float16& b) {
uint16_t res;
asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n"
......@@ -785,69 +783,69 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
// Arithmetic operators for float16, software emulated on other CPU
#else
HOST inline float16 operator+(const float16& a, const float16& b) {
inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b));
}
HOST inline float16 operator-(const float16& a, const float16& b) {
inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b));
}
HOST inline float16 operator*(const float16& a, const float16& b) {
inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b));
}
HOST inline float16 operator/(const float16& a, const float16& b) {
inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b));
}
HOST inline float16 operator-(const float16& a) {
inline float16 operator-(const float16& a) {
float16 res;
res.x = a.x ^ 0x8000;
return res;
}
HOST inline float16& operator+=(float16& a, const float16& b) {
inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b));
return a;
}
HOST inline float16& operator-=(float16& a, const float16& b) {
inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b));
return a;
}
HOST inline float16& operator*=(float16& a, const float16& b) {
inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b));
return a;
}
HOST inline float16& operator/=(float16& a, const float16& b) {
inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b));
return a;
}
HOST inline bool operator==(const float16& a, const float16& b) {
inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b);
}
HOST inline bool operator!=(const float16& a, const float16& b) {
inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b);
}
HOST inline bool operator<(const float16& a, const float16& b) {
inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b);
}
HOST inline bool operator<=(const float16& a, const float16& b) {
inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b);
}
HOST inline bool operator>(const float16& a, const float16& b) {
inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b);
}
HOST inline bool operator>=(const float16& a, const float16& b) {
inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
......
......@@ -307,15 +307,57 @@ class DistributeTranspiler:
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
for _, op in enumerate(self.optimize_ops):
for _, opt_op in enumerate(opt_op_on_pserver):
if ufind.is_connected(op, opt_op):
if self._is_opt_op(op):
self._append_pserver_ops(optimize_block, op, endpoint,
default_main_program())
else:
self._append_pserver_non_opt_ops(optimize_block, op)
break
# We try to put optimization program run parallelly, assume
# optimization program always looks like:
#
# prevop -> prevop -> opt op -> following op -> following op; ->
# prevop -> prevop -> opt op -> following op -> following op; ->
# global op -> global op
#
# we put operators that can run parallelly to many program blocks.
# in above example, we seperate ops by the ";". Global ops must run
# after all the optimize ops finished.
global_ops = []
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
for op in self.optimize_ops:
if op.type == "scale":
for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or\
in_name.startswith("beta2_pow_acc"):
global_ops.append(op)
def __append_optimize_op__(op, block):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint,
default_main_program())
else:
self._append_pserver_non_opt_ops(block, op)
# append op to the current block
per_opt_block = optimize_block
for _, opt_op in enumerate(opt_op_on_pserver):
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and \
op not in global_ops:
__append_optimize_op__(op, per_opt_block)
per_opt_block = pserver_program.create_block(0)
# append global ops
for glb_op in global_ops:
__append_optimize_op__(glb_op, per_opt_block)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
......@@ -660,10 +702,22 @@ class DistributeTranspiler:
# If one op's input is another op's output or
# one op's output is another op's input, we say
# the two operator is connected.
op1_input_names = op1.desc.input_arg_names()
def _append_inname_remove_beta(varname_list):
op_input_names = []
for in_name in varname_list:
# HACK: remove beta1 and beta2 to avoid let all
# ops connected.
if in_name.startswith("beta2_pow_acc") or \
in_name.startswith("beta1_pow_acc"):
continue
else:
op_input_names.append(in_name)
return op_input_names
op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
op1_output_names = op1.desc.output_arg_names()
op2_input_names = op2.desc.input_arg_names()
op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
op2_output_names = op2.desc.output_arg_names()
if set(op1_output_names) & set(op2_input_names) or \
......
......@@ -235,6 +235,77 @@ class Executor(object):
tensor.set_lod(lod)
return tensor
def _get_program_cache(self, program_cache_key):
return self.program_caches.get(program_cache_key, None)
def _add_program_cache(self, program_cache_key, program):
self.program_caches[program_cache_key] = program
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
fetch_var_name):
tmp_program = program.clone()
global_block = tmp_program.global_block()
if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
return tmp_program
def _feed_data(self, program, feed, feed_var_name, scope):
# feed var to framework
for op in program.global_block().ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
def _fetch_data(self, fetch_list, fetch_var_name, scope):
outs = [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
]
return outs
def run(self,
program=None,
feed=None,
......@@ -268,7 +339,6 @@ class Executor(object):
raise TypeError("feed should be a map")
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()
......@@ -278,79 +348,30 @@ class Executor(object):
if scope is None:
scope = global_scope()
program_cache = None
program_cache_key = get_program_cache_key(feed, fetch_list)
cache_key = get_program_cache_key(feed, fetch_list)
if use_program_cache:
# find program cache by cache_key
program_cache = self.program_caches.get(program_cache_key, None)
# TODO(qiao): Should check program_cache and program are exactly the same.
cached_program = self._get_program_cache(cache_key)
if cached_program is None:
cached_program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program)
program = cached_program
else:
self.program_caches.pop(program_cache_key, None)
if program_cache is None:
program_cache = program.clone()
if use_program_cache:
self.program_caches[program_cache_key] = program_cache
global_block = program_cache.global_block()
if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list,
fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
# feed var to framework
for op in program_cache.global_block().ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
self.executor.run(program_cache.desc, scope, 0, True, True)
outs = [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
]
self.program_caches.pop(cache_key, None)
program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope)
self.executor.run(program.desc, scope, 0, True, True)
outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy:
outs = as_numpy(outs)
return outs
......@@ -21,8 +21,8 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader',
'create_multi_pass_reader'
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader'
]
......@@ -288,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('multiple_reader')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
......
......@@ -1117,12 +1117,14 @@ def conv2d(input,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
use_mkldnn=False,
act=None):
act=None,
name=None):
"""
**Convlution2D Layer**
......@@ -1183,6 +1185,9 @@ def conv2d(input,
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
......@@ -1193,6 +1198,8 @@ def conv2d(input,
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act(str): Activation type. Default: None
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The tensor variable storing the convolution and \
......@@ -1233,6 +1240,7 @@ def conv2d(input,
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......@@ -1262,6 +1270,7 @@ def conv2d(input,
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': use_mkldnn
......@@ -1670,7 +1679,9 @@ def conv2d_transpose(input,
stride=1,
dilation=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
name=None):
"""
**Convlution2D transpose layer**
......@@ -1739,8 +1750,10 @@ def conv2d_transpose(input,
dilation_H = dilation_W = dilation. Default: dilation = 1.
param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer.
Default: None
bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act(str): Activation type. Default: None
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -1793,12 +1806,12 @@ def conv2d_transpose(input,
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
out = helper.create_tmp_variable(dtype=input.dtype)
pre_bias = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='conv2d_transpose',
inputs={'Input': [input],
'Filter': [img_filter]},
outputs={'Output': out},
outputs={'Output': pre_bias},
attrs={
'strides': stride,
'paddings': padding,
......@@ -1806,6 +1819,8 @@ def conv2d_transpose(input,
'use_cudnn': use_cudnn
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
out = helper.append_activation(pre_act)
return out
......
......@@ -664,6 +664,123 @@ class AdadeltaOptimizer(Optimizer):
return adadelta_op
class RMSPropOptimizer(Optimizer):
"""
Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning
rate method. The original slides proposed RMSProp: Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf .
The original equation is as follows:
.. math::
r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2 \\\\
w & = w - \\frac{\\eta} {\\sqrt{r(w,t) + \\epsilon}} \\nabla Q_{i}(w)
The first equation calculates moving average of the squared gradient for
each weight. Then dividing the gradient by :math: `sqrt{v(w,t)}`.
In some cases, adding a momentum term :math: `\\beta` is beneficial.
In our implementation, Nesterov momentum is used:
.. math::
r(w, t) & = \\rho r(w, t-1) + (1 - \\rho)(\\nabla Q_{i}(w))^2 \\\\
v(w, t) & = \\beta v(w, t-1) + \\frac{\\eta} {\\sqrt{v(w,t) +
\\epsilon}} \\nabla Q_{i}(w)
w & = w - v(w, t)
where, :math: `\\rho` is a hyperparameter and typical values are 0.9, 0.95
and so on. :math: `beta` is the momentum term. :math: `\\epsilon` is a
smoothing term to avoid division by zero, usually set somewhere in range
from 1e-4 to 1e-8.
Args:
learning_rate(float): global leraning rate.
rho(float): rho is :math: `\\rho` in equation, set 0.95 by default.
epsilon(float): :math: `\\epsilon` in equation is smoothing term to
avoid division by zero, set 1e-6 by default.
momentum(float): :math: `\\beta` in equation is the momentum term,
set 0.0 by default.
Raises:
ValueError: If learning_rate, rho, epsilon, momentum are None.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.RMSProp(0.0001)
_, params_grads = optimizer.minimize(cost)
"""
_momentum_acc_str = "momentum"
_mean_square_acc_str = "mean_square"
def __init__(self,
learning_rate,
rho=0.95,
epsilon=1.0e-6,
momentum=0.0,
**kwargs):
super(RMSPropOptimizer, self).__init__(
learning_rate=learning_rate, **kwargs)
if learning_rate is None:
raise ValueError("learning_rate is not set.")
if rho is None:
raise ValueError("rho is not set.")
if epsilon is None:
raise ValueError("epsilon is not set.")
if momentum is None:
raise ValueError("momentum is not set.")
self.type = "rmsprop"
self._rho = rho
self._epsilon = epsilon
self._momentum = momentum
def _create_accumulators(self, block, parameters):
if not isinstance(block, framework.Block):
raise TypeError("block is not instance of framework.Block.")
for p in parameters:
self._add_accumulator(self._momentum_acc_str, p)
self._add_accumulator(self._mean_square_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
if not isinstance(block, framework.Block):
raise TypeError("block is not instance of framework.Block.")
momentum_acc = self._get_accumulator(self._momentum_acc_str,
param_and_grad[0])
mean_square_acc = self._get_accumulator(self._mean_square_acc_str,
param_and_grad[0])
rmsprop_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Moment": momentum_acc,
"MeanSquare": mean_square_acc,
"LearningRate": self._create_param_lr(param_and_grad),
},
outputs={
"ParamOut": param_and_grad[0],
"MomentOut": momentum_acc,
"MeanSquareOut": mean_square_acc
},
attrs={
"epsilon": self._epsilon,
"decay": self._rho,
"momentum": self._momentum
})
return rmsprop_op
# We short the class name, since users will use the optimizer with the package
# name. The sample code:
#
......@@ -679,3 +796,4 @@ Adam = AdamOptimizer
Adamax = AdamaxOptimizer
DecayedAdagrad = DecayedAdagradOptimizer
Adadelta = AdadeltaOptimizer
RMSProp = RMSPropOptimizer
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
......@@ -13,158 +13,243 @@
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
class TestElementwiseOp(OpTest):
class TestElementwiseAddOp(OpTest):
def setUp(self):
self.op_type = "elementwise_add"
self.dtype = np.float32
self.axis = -1
self.init_dtype()
self.init_input_output()
self.init_axis()
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005)
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestElementwiseAddOp_scalar(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
def init_dtype(self):
pass
def init_axis(self):
pass
class TestElementwiseAddOp_scalar2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1, 1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
class TestFP16ElementwiseAddOp(TestElementwiseAddOp):
def init_dtype(self):
self.dtype = np.float16
class TestElementwiseAddOp_Vector(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.random((32, )).astype("float32"),
'Y': np.random.random((32, )).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestElementwiseAddOp_broadcast_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(2).astype(np.float32)
}
class TestElementwiseAddOp_scalar(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
self.attrs = {'axis': 0}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1)
}
class TestFP16ElementwiseAddOp_scalar(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(3).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 1)
}
class TestElementwiseAddOp_scalar2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(4).astype(np.float32)
}
class TestFP16ElementwiseAddOp_scalar2(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
self.out = self.x + self.y
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1, 4)
}
class TestElementwiseAddOp_Vector(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random((32, )).astype(self.dtype)
self.y = np.random.random((32, )).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
'Y': np.random.rand(3, 4).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4, 1)
}
class TestFP16ElementwiseAddOp_Vector(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random((32, )).astype(self.dtype)
self.y = np.random.random((32, )).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestElementwiseAddOp_broadcast_4(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
'Y': np.random.rand(2, 1).astype(np.float32)
}
class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(2).astype(self.dtype)
self.out = self.x + self.y.reshape(2, 1, 1)
self.attrs = {'axis': 0}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1, 1)
}
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(3, 4).astype(np.float32)
}
class TestFP16ElementwiseAddOp_broadcast_0(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(2).astype(self.dtype)
self.out = self.x + self.y.reshape(2, 1, 1)
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4)
}
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 1).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 3, 1)
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1)
}
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_broadcast_1(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 3, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(4).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1, 4)
class TestFP16ElementwiseAddOp_broadcast_2(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(4).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1, 4)
class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(3, 4).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 3, 4, 1)
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_broadcast_3(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(3, 4).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 3, 4, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(2, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(2, 1, 1, 1)
def init_axis(self):
self.axis = 0
class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(2, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(2, 1, 1, 1)
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3, 4).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 3, 4)
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_rowwise_add_0(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3, 4).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 3, 4)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
def init_axis(self):
self.axis = 1
if __name__ == '__main__':
......
......@@ -16,7 +16,6 @@ import unittest
import numpy
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
from shutil import copyfile
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files(
filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册