提交 4eaa7897 编写于 作者: K Kexin Zhao

resolve conflict

# select_op Design
## Introduction
In golang, the [**select**](https://golang.org/ref/spec#Select_statements)
statement lets a goroutine wait on multiple communication operations at the
same time. The **select** blocks until one of its cases can run, then
executes the case. If multiple cases are ready to run, then one case is
choosen at random to be executed.
With the introduction of CSP for Paddle, we mimic this behavior by
creating a ***select_op***.
## How to use it
The **select_op** is available as a c++ operator. However most users
will prefer to use the much simplier Python API.
- **fluid.Select()**: Creates a select operator and adds it to the current
block within the main program. Also creates a sub block and adds it to the
main program. This sub block is used to hold all variables and operators
used by the case statements.
Within the select block, users can add cases by
calling **select.case** or **select.default** method.
- **fluid.Select.case(channel_action, channel, result_variable)**: Represents
a fluid channel send/recv case. This method creates a SelectCase block
guard and adds it to the Select block. The arguments into this method tells
the select which channel operation to listen to.
- **fluid.Select.default()**: Represents the fluid default case. This default
case is executed if none of the channel send/recv cases are available to
execute.
**Example:**
```
ch1 = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
quit_ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
x = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
y = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=1)
while_cond = fill_constant(shape=[1], dtype=core.VarDesc.VarType.BOOL, value=True)
while_op = While(cond=while_cond)
with while_op.block():
with fluid.Select() as select:
with select.case(fluid.channel_send, channel, x):
# Send x, then perform Fibonacci calculation on x and y
x_tmp = fill_constant(shape=[1], dtype=core.VarDesc.VarType.INT32, value=0)
assign(input=x, output=x_tmp)
assign(input=y, output=x)
assign(elementwise_add(x=x_tmp, y=y), output=y)
with select.case(fluid.channel_recv, quit_channel, result2):
# Exit out of While loop
while_false = fill_constant(shape=[1], dtype=core.VarDesc.VarType.BOOL, value=False)
helper = layer_helper.LayerHelper('assign')
helper.append_op(
type='assign',
inputs={'X': [while_false]},
outputs={'Out': [while_cond]})
```
## How it Works
### Program Description
```
blocks {
idx: 0
...
// Create "case_to_execute" variable
ops {
outputs {
parameter: "Out"
arguments: "fill_constant_110.tmp_0"
}
type: "fill_constant"
attrs {
name: "force_cpu"
type: BOOLEAN
b: false
}
attrs {
name: "value"
type: FLOAT
f: -1.0
}
attrs {
name: "shape"
type: INTS
ints: 1
}
attrs {
name: "dtype"
type: INT
i: 2
}
}
// Create "select" operator.
// inputs:
// X: All input variables used by operators within the select block
// case_to_execute: Variable filled in by select_op when it determines
// which case to execute.
//
// outputs:
// Out: All output variables referenced by operators within select block.
//
// attrs:
// sub_block: The block id containing the select "cases"
// cases: Serialized list of all cases in the select op.
// Each case is serialized as: '<index>,<type>,<channel>,<value>'
// where type is 0 for default, 1 for send, and 2 for receive.
// No channel and values are needed for default cases.
ops {
inputs {
parameter: "X"
arguments: "fill_constant_103.tmp_0"
arguments: "fill_constant_104.tmp_0"
}
inputs {
parameter: "case_to_execute"
arguments: "fill_constant_110.tmp_0"
}
outputs {
parameter: "Out"
arguments: "fill_constant_110.tmp_0"
}
type: "select"
attrs {
name: "sub_block"
type: BLOCK
block_idx: 1
}
attrs {
name: "cases"
type: STRINGS
strings: "0,1,channel_101,fill_constant_109.tmp_0"
strings: "1,2,channel_102,fill_constant_108.tmp_0"
}
}
...
}
```
The python select API will add the **select_op** to the current block. In addition, it will
iterate through all it's case statements and add any input variables required by case statements
into **X**. It will also create a temp variable called **case_to_execute**. This variable is
filled in by the select_op after it has completed processing the case statements.
If there are no available cases to execute (ie: all cases are blocked on channel operations, and
there is no default statement), then the select_op will block the current thread. The thread will
unblock once there is a channel operation affecting one of the case statements, at which point, the
**select_op** will set the **case_to_execute** variable to the index of the case to execute.
Finally the select_op will call executor.run on the **sub_block**.
```
blocks {
idx: 1
parent_idx: 0
...
// Fill a tensor with the case index (ie: 0,1,2,3,ect.)
ops {
outputs {
parameter: "Out"
arguments: "fill_constant_111.tmp_0"
}
type: "fill_constant"
attrs {
name: "force_cpu"
type: BOOLEAN
b: false
}
attrs {
name: "value"
type: FLOAT
f: 0.0
}
attrs {
name: "shape"
type: INTS
ints: 1
}
attrs {
name: "dtype"
type: INT
i: 2
}
}
// Create an "equal" operator to compare the case index with the "case_to_execute"
// tensor (which was filled in by the select op).
ops {
inputs {
parameter: "X"
arguments: "fill_constant_111.tmp_0" // case 0
}
inputs {
parameter: "Y"
arguments: "fill_constant_110.tmp_0" // case_to_execute
}
outputs {
parameter: "Out"
arguments: "equal_0.tmp_0"
}
type: "equal"
attrs {
name: "axis"
type: INT
i: -1
}
}
// Use the output of the "equal" operator as a condition for the "conditional_block".
// If the condition evaluates to true, then execute the "sub_block" (which represents
// the select case's body)
ops {
inputs {
parameter: "Params"
}
inputs {
parameter: "X"
arguments: "equal_0.tmp_0"
}
outputs {
parameter: "Out"
}
outputs {
parameter: "Scope"
arguments: "_generated_var_0"
}
type: "conditional_block"
attrs {
name: "is_scalar_condition"
type: BOOLEAN
b: true
}
attrs {
name: "sub_block"
type: BLOCK
block_idx: 4
}
}
...
// Repeat the above operators for each case statements inside the select body
}
```
Cases are represented by a **conditional_block operator**, whose's condition is set as the output of
equal(**case_to_execute**, **case_index**). Since each case index is unique in this sub-block,
only one case will be executed.
### select_op flow
<p align="center">
<img src="./images/select_op_workflow.png"/><br/>
</p>
The select algorithm is inspired by golang's select routine. Please refer to
http://www.tapirgames.com/blog/golang-concurrent-select-implementation for more information.
## Backward Pass
TODO
......@@ -5,7 +5,7 @@ This document describes the RNN (Recurrent Neural Network) operator and how it i
## RNN Algorithm Implementation
<p align="center">
<img src="./images/rnn.jpg"/>
<img src="./rnn.jpg"/>
</p>
The above diagram shows an RNN unrolled into a full network.
......@@ -22,7 +22,7 @@ There are several important concepts here:
There could be local variables defined in each step-net. PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step.
<p align="center">
<img src="./images/rnn.png"/><br/>
<img src="./rnn.png"/><br/>
Figure 2 illustrates the RNN's data flow
</p>
......@@ -49,7 +49,7 @@ or copy the memory value of the previous step to the current ex-memory variable.
### Usage in Python
For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md).
For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/block.md).
We can define an RNN's step-net using a Block:
......@@ -93,7 +93,7 @@ For example, we could have a 2-level RNN, where the top level corresponds to par
The following figure illustrates feeding in text into the lower level, one sentence at a step, and the feeding in step outputs to the top level. The final top level output is about the whole text.
<p align="center">
<img src="./images/2_level_rnn.png"/>
<img src="./2_level_rnn.png"/>
</p>
```python
......@@ -149,5 +149,5 @@ If the `output_all_steps` is set to False, it will only output the final time st
<p align="center">
<img src="images/rnn_2level_data.png"/>
<img src="./rnn_2level_data.png"/>
</p>
# Fluid 分布式版本使用指南
本篇文章将说明如何在PaddlePaddle Fluid版本下进行分布式训练的配置和执行,以及将单机训练脚本改造成支持集群训练的版本
## 准备工作
* 可用的集群
包含一个或多个计算节点的集群,每一个节点都能够执行PaddlePaddle的训练任务且拥有唯一的IP地址,集群内的所有计算节点可以通过网络相互通信。
* 安装PaddlePaddle Fluid with Distribution版本
所有的计算节点上均需要按照分布式版本的PaddlePaddle, 在用于GPU等设备的机器上还需要额外安装好相应的驱动程序和CUDA的库。
**注意:**当前对外提供的PaddlePaddle版本并不支持分布式,需要通过源码重新编译。编译和安装方法参见[编译和安装指南](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)。
cmake编译命令中需要将WITH_DISTRIBUTE设置为ON,下面是一个cmake编译指令示例:
``` bash
cmake .. -DWITH_DOC=OFF -DWITH_GPU=OFF -DWITH_DISTRIBUTE=ON -DWITH_SWIG_PY=ON -DWITH_PYTHON=ON
```
## 更新训练脚本
这里,我们以[Deep Learing 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html)课程中的第一章 fit a line 为例,描述如何将单机训练脚本改造成支持集群训练的版本。
### 单机训练脚本示例
```python
import paddle.v2 as paddle
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 20
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
PASS_NUM = 100
for pass_id in range(PASS_NUM):
fluid.io.save_persistables(exe, "./fit_a_line.model/")
fluid.io.load_persistables(exe, "./fit_a_line.model/")
for data in train_reader():
avg_loss_value, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
if avg_loss_value[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good.
exit(1)
```
我们创建了一个简单的全连接神经网络程序,并且通过Fluid的Executor执行了100次迭代,现在我们需要将该单机版本的程序更新为分布式版本的程序。
### 介绍Parameter Server
在非分布式版本的训练脚本中,只存在Trainer一种角色,它不仅处理常规的计算任务,也处理参数相关的计算、保存和优化任务。在分布式版本的训练过程中,由于存在多个Trainer节点进行同样的数据计算任务,因此需要有一个中心化的节点来统一处理参数相关的保存和分配。在PaddlePaddle中,我们称这样的节点为[Parameter Server](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md)
**因此,在分布式的Fluid环境中,我们有两个角色需要创建,分别是Parameter Server和Trainer。**
### 分布式训练
Fliud专门提供了工具[Distributed Transpiler](https://github.com/PaddlePaddle/Paddle/blob/ba65d54d9d3b41cd3c5171b00f476d4e60133ddb/doc/fluid/design/dist_train/distributed_architecture.md#distributed-transpiler)用于将单机版的训练程序转换为分布式版本的训练程序。工具背后的理念是找出程序的优化算子和梯度参数,将他们分隔为两部分,通过send/recv 操作算子进行连接,优化算子和梯度参数可以在优化器的minimize函数的返回值中获取到。
```python
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
```
将Distributed Transpiler、优化算子和梯度函数放在一个代码中如下:
```python
... #define the program, cost, and create sgd optimizer
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) #get optimize OPs and gradient parameters
t = fluid.DistributeTranspiler() # create the transpiler instance
# slice the program into 2 pieces with optimizer_ops and gradient parameters list, as well as pserver_endpoints, which is a comma separated list of [IP:PORT] and number of trainers
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
... #create executor
# in pserver, run this
#current_endpoint here means current pserver IP:PORT you wish to run on
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
# in trainer, run this
... # define data reader
exe.run(fluid.default_startup_program())
for pass_id in range(100):
for data in train_reader():
exe.run(t.get_trainer_program())
```
### 分布式训练脚本运行说明
分布式任务的运行需要将表格中说明的多个参数进行赋值:
| 参数名 | 值类型 | 说明 | 示例 |
|:-------------|:------|:---------------------------------------|:-------------|
| trainer_id | int | 当前训练节点的ID,训练节点ID编号为0 - n-1, n为trainers的值 | 0/1/2/3 |
| pservers | str | parameter server 列表 | 127.0.0.1:6710,127.0.0.1:6711 |
| trainers | int | 训练节点的总个数,>0的数字 | 4 |
| server_endpoint | str | 当前所起的服务节点的IP:PORT | 127.0.0.1:8789 |
| training_role | str | 节点角色, TRAINER/PSERVER | PSERVER |
**注意:** ```training_role```是用来区分当前所起服务的角色的,用于训练程序中,用户可根据需要自行定义,其他参数为fluid.DistributeTranspiler的transpile函数所需要,需要在调用函数前进行定义,样例如下:
```python
t = fluid.DistributeTranspiler()
t.transpile(
optimize_ops,
params_grads,
trainer_id,
pservers=pserver,
trainers=trainers)
if training_role == "PSERVER":
pserver_prog = t.get_pserver_program(server_endpoint)
pserver_startup = t.get_startup_program(server_endpoint, pserver_prog)
```
### Demo
完整的demo代码位于Fluid的test目录下的[book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py)中。
第一步,进入demo代码所在目录:
```bash
cd /paddle/python/paddle/fluid/tests/book
```
第二步,启动Parameter Server:
```bash
PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py
```
执行命令后请等待出现提示: ```Server listening on 192.168.1.2:6174 ```, 表示Paramter Server已经正常启动。
第三步,启动Trainer:
```bash
PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py
```
由于我们定义的Trainer的数量是2个,因此需要在另外一个计算节点上再启动一个Trainer。
现在我们就启动了一个包含一个Parameter Server和两个Trainer的分布式训练任务。
FAQ
====
This document provides answers to some of the frequently asked questions about PaddlePaddle. If you have a question that is not covered here, please go to `PaddlePaddle Community <https://github.com/PaddlePaddle/Paddle/issues>`_ , to find an answer or submit new `issue <https://github.com/PaddlePaddle/Paddle/issues/new>`_ , we will reply in time.
.. toctree::
:maxdepth: 1
......
......@@ -185,7 +185,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
const std::string& fetch_holder_name, bool create_vars) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
......@@ -255,7 +255,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
Run(*copy_program, scope, 0, true, true);
Run(*copy_program, scope, 0, create_vars, create_vars);
// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block->AllOps()) {
......
......@@ -54,7 +54,8 @@ class Executor {
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
......
......@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// For float or float16 input tensor, the type of the scale, bias, mean,
// and var tensors should both be float.
auto bn_param_type = framework::proto::VarType::FP32;
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
"Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
"Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <cfloat>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -26,6 +27,8 @@ using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
int *N, int *C, int *H, int *W, int *D) {
......@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
// Note: PERSISTENT not implemented for inference
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
......@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(dev_ctx, saved_mean, 0);
functor(dev_ctx, saved_variance, 0);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
auto handle = dev_ctx.cudnn_handle();
......@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
......@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(), bias->template data<T>(), this_factor,
mean_out->template mutable_data<T>(ctx.GetPlace()),
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean->template mutable_data<T>(ctx.GetPlace()),
saved_variance->template mutable_data<T>(ctx.GetPlace())));
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
}
// clean when exit.
......@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
batch_norm,
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
......@@ -78,7 +78,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
int64_t index = i * class_num + label_data[i];
dx_data[index] = -dy_data[i] / x_data[index];
dx_data[index] = math::TolerableValue<T>()(-dy_data[i] / x_data[index]);
}
}
}
......
......@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
cblas_daxpy(n, alpha, x, 1, y, 1);
}
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
......
......@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
&alpha, x, 1, y, 1));
}
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
......
......@@ -15,10 +15,12 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE)
endfunction()
reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
......@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(ReaderBase* reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out);
}
bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
++pass_count_;
if (pass_count_ >= pass_num_) {
return false;
} else {
reader_->ReInit();
return true;
}
}
}
void ReInit() override {
pass_count_ = 0;
reader_->ReInit();
}
private:
int pass_num_;
mutable int pass_count_;
};
class CreateMultiPassReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto& out = detail::Ref(scope.FindVar(Output("Out")));
int pass_num = Attr<int>("pass_num");
out.GetMutable<framework::ReaderHolder>()->Reset(
new MultiPassReader(underlying_reader.Get(), pass_num));
}
};
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
AddComment(R"DOC(
CreateMultiPassReader Operator
This operator creates a multi-pass reader. A multi-pass reader
is used to yield data for several pass training continuously.
It takes the the number of pass to run as one of its attributes
('pass_num'), and maintains a pass counter to record how many
passes it has completed. When the underlying reader reach the EOF,
the multi-pass reader checks whether it has completed training
of the given number of pass. If not, the underlying reader will
be re-initialized and starts a new pass automatically.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
ops::CreateMultiPassReaderOp,
ops::CreateMultiPassReaderOpMaker);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class MultipleReader : public framework::ReaderBase {
public:
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
~MultipleReader() { EndScheduler(); }
private:
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_;
local_buffer_.clear();
}
bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
}
void MultipleReader::ReInit() {
EndScheduler();
local_buffer_.clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
if (!available_thread_idx_->Send(&thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}
class OpenFilesOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
}
};
class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
reader::OpenFilesOpMaker);
......@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs;
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_name].[file_format] (e.g., 'data_file.recordio').");
std::string filetype = file_name.substr(separator_pos + 1);
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
......
......@@ -21,6 +21,8 @@ namespace paddle {
namespace operators {
namespace reader {
static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
......@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dims);
};
return 0;
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include <iostream>
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; // Note: paddle has also "memory" namespace
using mkldnn::primitive;
using mkldnn::softmax_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE(input->dims().size() == 2UL,
"The input of softmax op must be a 2D matrix.");
const T* input_data = input->data<T>();
// allocate memory for output
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
// we will make normalization after final eg. axis: 1
PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
"Softmax input and output dimensions should match");
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Currently only supports NC data format
// TODO(jczaja-intel): support more formats
auto softmax_md =
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
// create memory primitives
auto softmax_src_memory =
memory({softmax_md, mkldnn_engine}, (void*)input_data);
auto softmax_dst_memory =
memory({softmax_md, mkldnn_engine}, (void*)output_data);
auto softmax_prim_desc =
softmax_forward::primitive_desc(softmax_desc, mkldnn_engine);
auto softmax = softmax_forward(softmax_prim_desc, softmax_src_memory,
softmax_dst_memory);
std::vector<primitive> pipeline{softmax};
stream(stream::kind::eager).submit(pipeline).wait();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNKernel<float>);
......@@ -17,6 +17,9 @@ limitations under the License. */
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -47,13 +50,26 @@ class SoftmaxOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::StringToDataLayout(data_format),
library_);
}
};
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,6 +89,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Softmax Operator.
......
......@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
// The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType;
using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
......@@ -101,7 +102,8 @@ template <>
class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
using ScalingParamType = const float;
using BatchNormParamType = float;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
......@@ -116,7 +118,8 @@ template <>
class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
typedef const double ScalingParamType;
using ScalingParamType = const double;
using BatchNormParamType = double;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
......
......@@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
bool Header::Parse(std::istream& is) {
uint32_t magic;
size_t read_size =
is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
is.read(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
size_t read_size = is.gcount();
if (read_size < sizeof(uint32_t)) {
return false;
}
......
......@@ -28,6 +28,7 @@ Scanner::Scanner(const std::string &filename) {
}
void Scanner::Reset() {
stream_->clear();
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
}
......
......@@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0):
return channel
def channel_send(channel, value, copy=False):
def channel_send(channel, value, is_copy=False):
"""
Sends a value through a channel variable. Used by an unbuffered or buffered
channel to pass data from within or to a concurrent Go block, where
......@@ -141,8 +141,8 @@ def channel_send(channel, value, copy=False):
channel (Variable|Channel): Channel variable created using
`make_channel`.
value (Variable): Value to send to channel
copy (bool): Copy data while channel send. If False, then data
is moved. The input cannot be used after move.
is_copy (bool): Copy data while channel send. If False, then data
is moved. The input cannot be used after move. (default False)
Returns:
Variable: The boolean status on whether or not the channel
successfully sent the passed value.
......@@ -166,7 +166,7 @@ def channel_send(channel, value, copy=False):
X = value
if copy is True:
if is_copy is True:
copied_X = helper.create_variable(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
......
......@@ -399,6 +399,9 @@ class LayerHelper(object):
if isinstance(act, basestring):
act = {'type': act}
tmp = self.create_tmp_variable(dtype=input_var.dtype)
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
act_type = act.pop('type')
self.append_op(
type=act_type,
......
......@@ -21,7 +21,8 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader'
]
......@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('multiple_reader')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
......@@ -314,6 +345,11 @@ def create_double_buffer_reader(reader, place=None):
attrs)
def create_multi_pass_reader(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})
def read_file(file_obj):
helper = LayerHelper('read_file')
out = [
......
......@@ -82,6 +82,7 @@ def fc(input,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
use_mkldnn=False,
act=None,
name=None):
"""
......@@ -163,8 +164,11 @@ def fc(input,
inputs={"X": input_var,
"Y": w},
outputs={"Out": tmp},
attrs={"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1})
attrs={
"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1,
'use_mkldnn': use_mkldnn
})
mul_results.append(tmp)
# sum
......@@ -1117,12 +1121,14 @@ def conv2d(input,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
use_mkldnn=False,
act=None):
act=None,
name=None):
"""
**Convlution2D Layer**
......@@ -1183,6 +1189,9 @@ def conv2d(input,
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
......@@ -1193,6 +1202,8 @@ def conv2d(input,
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act(str): Activation type. Default: None
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The tensor variable storing the convolution and \
......@@ -1233,6 +1244,7 @@ def conv2d(input,
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......@@ -1262,6 +1274,7 @@ def conv2d(input,
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': use_mkldnn
......@@ -1670,7 +1683,9 @@ def conv2d_transpose(input,
stride=1,
dilation=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
name=None):
"""
**Convlution2D transpose layer**
......@@ -1739,8 +1754,10 @@ def conv2d_transpose(input,
dilation_H = dilation_W = dilation. Default: dilation = 1.
param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer.
Default: None
bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act(str): Activation type. Default: None
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -1793,12 +1810,12 @@ def conv2d_transpose(input,
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
out = helper.create_tmp_variable(dtype=input.dtype)
pre_bias = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='conv2d_transpose',
inputs={'Input': [input],
'Filter': [img_filter]},
outputs={'Output': out},
outputs={'Output': pre_bias},
attrs={
'strides': stride,
'paddings': padding,
......@@ -1806,6 +1823,8 @@ def conv2d_transpose(input,
'use_cudnn': use_cudnn
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
out = helper.append_activation(pre_act)
return out
......
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
......@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op
def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 2:
if data_format == "NCHW":
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
else:
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
if data_format == "NCHW":
n, c, h, w = x.shape
mean_tile = np.reshape(mean, (1, c, 1, 1))
mean_tile = np.tile(mean_tile, (n, 1, h, w))
var_tile = np.reshape(var, (1, c, 1, 1))
var_tile = np.tile(var_tile, (n, 1, h, w))
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
elif data_format == "NHWC":
normalized = (x - mean) / np.sqrt(var + epsilon)
y = normalized * scale + offset
else:
raise ValueError("Unknown data order.")
if len(x_shape) == 2:
y = np.reshape(y, x_shape)
return y
def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 2:
......@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
__set_tensor__(output, data)
class TestBatchNormOp(OpTest):
class TestBatchNormOpInference(OpTest):
def setUp(self):
self.dtype = np.float32
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_python(self):
def check_with_place(self, place, data_layout, dtype, shape):
epsilon = 0.00001
if len(shape) == 2:
x_shape = shape
c = x_shape[1]
else:
n, h, w, c = shape[0], shape[1], shape[2], shape[3]
if data_layout == "NHWC":
x_shape = [n, h, w, c]
elif data_layout == "NCHW":
x_shape = [n, c, h, w]
else:
raise ValueError("Unknown data layout.")
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(dtype)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, data_layout).astype(dtype)
scope = core.Scope()
# create input
x_tensor = create_or_get_tensor(scope, "x_val",
OpTest.np_dtype_to_fluid_dtype(x_val),
place)
scale_tensor = create_or_get_tensor(
scope, "scale_val",
OpTest.np_dtype_to_fluid_dtype(scale_val), place)
bias_tensor = create_or_get_tensor(
scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
mean_tensor = create_or_get_tensor(scope, "mean",
OpTest.np_dtype_to_fluid_dtype(mean),
place)
variance_tensor = create_or_get_tensor(
scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)
# create output
y_tensor = create_or_get_tensor(scope, "y_out", None, place)
saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
place)
saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
None, place)
mean_out_tensor = mean_tensor
variance_out_tensor = variance_tensor
batch_norm_op = Operator(
"batch_norm",
# inputs
X="x_val",
Scale="scale_val",
Bias="bias_val",
Mean="mean",
Variance="variance",
# outputs
Y="y_out",
MeanOut="mean",
VarianceOut="variance",
SavedMean="saved_mean",
SavedVariance="saved_variance",
# attrs
is_test=True,
data_layout=data_layout,
epsilon=epsilon)
batch_norm_op.run(scope, place)
# check inference result
self.__assert_close(
y_tensor,
y_out,
"inference output are different at " + str(place) + ", " +
data_layout + ", " + str(np.dtype(dtype)) +
str(np.array(y_tensor)) + str(y_out),
atol=1e-3)
def test_check_output(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(core.CUDAPlace(0))
for place in places:
for data_format in ["NCHW", "NHWC"]:
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestFP16BatchNormOpInference(TestBatchNormOpInference):
def setUp(self):
self.dtype = np.float16
def test_check_output(self):
places = []
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
places.append(place)
for place in places:
for data_format in ["NCHW", "NHWC"]:
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestBatchNormOpTraining(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_python_testing(self):
data_format = "NHWC"
epsilon = 0.00001
n, h, w, c = 2, 3, 4, 5
x_shape = [n, h, w, c]
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, "NHWC")
# running N, C, H, W case
# should produce the same results
x_shape2 = [n, c, h, w]
x_val2 = np.transpose(x_val, (0, 3, 1, 2))
y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
epsilon, "NCHW")
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "inference output")
print 'python: NHWC, NCHW, inference checking passed'
def test_python_training(self):
data_format = "NHWC"
epsilon = 0.00001
momentum = 0.9
......@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "batch variance")
self.__assert_close(y_out, y_out2_trans, "batch output")
print 'python: NHWC, NCHW, forward checking passed'
# test backward now
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
self.pass_num = 3
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', data_file, feeder)
def test_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
data_file = fluid.layers.create_multi_pass_reader(
reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_file.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
self.assertEqual(batch_count, self.num_batch * self.pass_num)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
from shutil import copyfile
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files(
filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
......@@ -29,6 +29,7 @@ class TestSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
self.use_mkldnn = False
self.dtype = np.float32
self.init_kernel_type()
......@@ -36,7 +37,10 @@ class TestSoftmaxOp(OpTest):
out = np.apply_along_axis(stable_softmax, 1, x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {'use_cudnn': self.use_cudnn, }
self.attrs = {
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
}
def init_kernel_type(self):
pass
......@@ -76,5 +80,10 @@ class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
self.check_output_with_place(place, atol=1e-3)
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_mkldnn = True
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册