Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
62d597c1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
62d597c1
编写于
9月 26, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of github.com:baidu/Paddle into feature/pybind_for_protobuf_desc
上级
16c5f629
367a54e0
变更
34
隐藏空白更改
内联
并排
Showing
34 changed file
with
873 addition
and
380 deletion
+873
-380
benchmark/paddle/image/run_mkldnn.sh
benchmark/paddle/image/run_mkldnn.sh
+1
-4
doc/design/refactor/distributed_architecture.md
doc/design/refactor/distributed_architecture.md
+222
-0
doc/design/refactor/parameter_server.md
doc/design/refactor/parameter_server.md
+0
-0
doc/design/refactor/src/compiler.graffle
doc/design/refactor/src/compiler.graffle
+0
-0
doc/design/refactor/src/compiler.png
doc/design/refactor/src/compiler.png
+0
-0
doc/design/refactor/src/dist-graph.graffle
doc/design/refactor/src/dist-graph.graffle
+0
-0
doc/design/refactor/src/dist-graph.png
doc/design/refactor/src/dist-graph.png
+0
-0
doc/design/refactor/src/distributed_architecture.graffle
doc/design/refactor/src/distributed_architecture.graffle
+0
-0
doc/design/refactor/src/distributed_architecture.png
doc/design/refactor/src/distributed_architecture.png
+0
-0
doc/design/refactor/src/local-graph.graffle
doc/design/refactor/src/local-graph.graffle
+0
-0
doc/design/refactor/src/local-graph.png
doc/design/refactor/src/local-graph.png
+0
-0
doc/design/refactor/src/local_architecture.graffle
doc/design/refactor/src/local_architecture.graffle
+0
-0
doc/design/refactor/src/local_architecture.png
doc/design/refactor/src/local_architecture.png
+0
-0
doc/design/refactor/src/paddle-compile.graffle
doc/design/refactor/src/paddle-compile.graffle
+0
-0
doc/design/refactor/src/paddle-compile.png
doc/design/refactor/src/paddle-compile.png
+0
-0
doc/faq/index_cn.rst
doc/faq/index_cn.rst
+91
-13
doc/getstarted/build_and_install/docker_install_cn.rst
doc/getstarted/build_and_install/docker_install_cn.rst
+1
-1
paddle/gserver/activations/MKLDNNActivation.cpp
paddle/gserver/activations/MKLDNNActivation.cpp
+187
-25
paddle/gserver/activations/MKLDNNActivation.h
paddle/gserver/activations/MKLDNNActivation.h
+35
-100
paddle/gserver/tests/test_MKLDNN.cpp
paddle/gserver/tests/test_MKLDNN.cpp
+2
-2
paddle/operators/accuracy_op.cu
paddle/operators/accuracy_op.cu
+6
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+51
-32
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+92
-55
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+56
-58
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+8
-3
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+34
-28
paddle/operators/multiplex_op.cu
paddle/operators/multiplex_op.cu
+23
-20
paddle/operators/multiplex_op.h
paddle/operators/multiplex_op.h
+13
-10
paddle/operators/top_k_op.cu
paddle/operators/top_k_op.cu
+6
-4
paddle/parameter/FirstOrderOptimizer.h
paddle/parameter/FirstOrderOptimizer.h
+11
-0
paddle/parameter/ParameterUpdateFunctions.cpp
paddle/parameter/ParameterUpdateFunctions.cpp
+3
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+1
-1
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+23
-17
python/paddle/v2/framework/tests/test_multiplex_op.py
python/paddle/v2/framework/tests/test_multiplex_op.py
+7
-5
未找到文件。
benchmark/paddle/image/run_mkldnn.sh
浏览文件 @
62d597c1
...
@@ -9,11 +9,9 @@ function train() {
...
@@ -9,11 +9,9 @@ function train() {
bs
=
$2
bs
=
$2
use_mkldnn
=
$3
use_mkldnn
=
$3
if
[
$3
==
"True"
]
;
then
if
[
$3
==
"True"
]
;
then
use_mkldnn
=
$3
thread
=
1
thread
=
1
log
=
"logs/
${
topology
}
-mkldnn-
${
bs
}
.log"
log
=
"logs/
${
topology
}
-mkldnn-
${
bs
}
.log"
elif
[
$3
==
"False"
]
;
then
elif
[
$3
==
"False"
]
;
then
use_mkldnn
=
$3
thread
=
`
nproc
`
thread
=
`
nproc
`
log
=
"logs/
${
topology
}
-
${
thread
}
mklml-
${
bs
}
.log"
log
=
"logs/
${
topology
}
-
${
thread
}
mklml-
${
bs
}
.log"
else
else
...
@@ -39,8 +37,7 @@ if [ ! -d "logs" ]; then
...
@@ -39,8 +37,7 @@ if [ ! -d "logs" ]; then
mkdir
logs
mkdir
logs
fi
fi
#========= mkldnn =========#
#========== mkldnn ==========#
# vgg
train vgg 64 True
train vgg 64 True
train vgg 128 True
train vgg 128 True
train vgg 256 True
train vgg 256 True
...
...
doc/design/refactor/distributed_architecture.md
0 → 100644
浏览文件 @
62d597c1
# Design Doc: Distributed Training Architecture
## Abstract
PaddlePaddle v0.10.0 uses the "trainer-parameter server"
architecture. We run multiple replicated instances of trainers (runs
the same code written by the user) and parameter servers for
distributed training. This architecture served us well, but has some
limitations:
1.
Need to write special code to handle tasks which should only be run
by a single trainer. E.g., initializing model and saving model.
2.
Model parallelism is hard: need to write if-else branches conditioned
on the trainer ID to partition model onto each trainer, and manually
write the inter-model-shard communication code.
3.
The user can not directly specify the parameter update rule: need
to modify the parameter server C++ code and compile a new
binary. This adds complication for researchers: A lot of extra
effort is required. Besides, the training job submission program
may not allow running arbitrary binaries.
This design doc discusses PaddlePaddle's new distributed training
architecture that addresses the above limitations.
## Analysis
We will assume the user writes the trainer program by Python, the same
analysis holds if the trainer program is written in C++.
### Limitation 1
If we look at the Python code that the user writes, there are two
kinds of functionalities:
-
The training logic such as load / save model and print log.
-
The neural network definition such as the definition of the data
layer, the fully connected layer, the cost function and the
optimizer.
When we training with PaddlePaddle v0.10.0 distributedly, multiple
replicated Python instances are running on different nodes: both the
training logic and the neural network computation is replicated.
The tasks that should only run once all belong to the training logic,
if we only replicate the neural network computation, but do
**not**
replicate the training logic, the limitation could be solved.
### Limitation 2
Model parallelism means running a single model on multiple nodes by
partitioning the model onto different nodes and managing the
inter-model-shard communications.
PaddlePaddle should be able to modify the nerual network computation
definition to support model parallelism automatically. However, the
computation is only specified in Python code, and PaddlePaddle can not
modify Python code.
Just like compiler uses a intermediate representation (IR) so that
programmer does not need to manually optimize their code in most of
the cases - the compiler will optimize the IR:
<img
src=
"src/compiler.png"
/>
We can have our own IR too: PaddlePaddle can support model parallel by
converting the IR so the user no longer need to manually do it in
Python:
<img
src=
"src/paddle-compile.png"
/>
The IR for PaddlePaddle after refactor is called
`Block`
, it specifies
the computation dependency graph and the variables used in the
computation.
### Limitation 3
The user can not directly specify the parameter update rule for the
parameter server because the parameter server does not use the same
computation definition as the trainer. Instead, the update rule is
baked in the parameter server. The user can not specify the update
rule in the same way of specifying the trainer computation.
This could be fixed by making the parameter server run the same
computation definition as the trainer. For a detailed explanation,
please
see
[
Design Doc: Operation Graph Based Parameter Server
](
./dist_train.md
)
## Distributed Training Architecture
The new distributed training architecture can address the above
limitations. Below is the illustration:
<img
src=
"src/distributed_architecture.png"
/>
The architecture includes major components:
*PaddlePaddle Python*
,
*PaddlePaddle converter*
and
*PaddlePaddle runtime*
:
### PaddlePaddle Python
PaddlePaddle Python is the Python library that user's Python trainer
invoke to build the neural network topology, start training, etc.
```
Python
paddle.init()
input = paddle.op.recordIO("/home/data/mnist.recordio") # file stored on the cluster
img, label = input[0], input[1]
hidden = paddle.layer.fc(input=img, size=200, act=paddle.activation.Tanh())
prediction = paddle.layer.fc(input=img, size=10, act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=prediction, label=label)
optimizer = paddle.optimizer.SGD(cost, learning_rate=0.01)
session = paddle.session.NewRemote(num_trainer=3, num_ps=2, GPU_per_trainer=1)
for i in range(1000):
_, cost_val = session.eval(targets=[cost, optimizer])
print cost_val
```
The code above is a typical Python trainer code, the neural network
topology is built using helper functions such as
`paddle.layer.fc`
. The training is done by calling
`session.eval`
iteratively.
#### session.eval
As shown in the graph,
`session.eval`
sends the IR and the evaluation
inputs/targets to the PaddlePaddle cluster for evaluation. The
targets can be any variable in the computation graph. When the target
is the
`optimizer`
variable, the neural network will be optimized
once. When the target is the
`cost`
variable,
`session.eval`
returns
the cost value.
The Python
`session`
is a wrapper of the C++
`Session`
class. For more
information about
`Session`
, please
see
[
Design Doc: Session
](
./session.md
)
.
### PaddlePaddle Converter
PaddlePaddle converter automatically converts the IR in the request
(IR and evaluation inputs/targets) from PaddlePaddle Python to new
partitioned IRs and dispatch the new IRs and evaluation inputs/targets
to different PaddlePaddle runtimes. Below are the steps:
1.
Add
`feed`
OP that feeds the eval inputs, and
`fetch`
OP that
fetches the eval targets to the IR.
1.
Extract a new computation (sub)graph with
`feed`
and
`fetch`
OP as
the boundary. The runtime does not need to run the OP that is not
dependent by the
`fetch`
OP.
1.
Optimizes the computation graph.
1.
Place the OPs in the graph onto different devices on different
PaddlePaddle runtime according to a placement algorithm and device
constraint specified by the user.
1.
Partition the graph according to runtime boundaries and add
`send`
/
`recv`
OP pair on the runtime boundaries.
1.
Dispatch the partitioned graph to different PaddlePaddle runtimes.
1.
PaddlePaddle runtimes with the
`fetch`
OP reports evaluation
results back to the converter, the convert reports the evaluation
results back to the PaddlePaddle Python.
The output IRs will be cached to optimize the conversion latency.
#### Placement Algorithm
Our first implementation will only support "trainer-parameter server"
placement: the parameters, initializers, and optimizers are placed on
the PaddlePaddle runtimes with the parameter server role. And
everything else will be placed on the PaddlePaddle runtimes with the
trainer role. This has the same functionality of our
"trainer-parameter server" architecture of PaddlePaddle v0.10.0, but
is more general and flexible.
In the future, we will implement the general placement algorithm,
which makes placements according to the input IR, and a model of
device computation time and device communication time. Model
parallelism requires the general placement algorithm.
### PaddlePaddle Runtime
The PaddlePaddle runtime owns multiple devices (e.g., CPUs, GPUs) and
runs the IR. The runtime does not need to do OP placement since it's
already done by the converter.
### Local Training Architecture
The local training architecture will be the same as the distributed
training architecture, the differences are everything runs locally,
and there is just one PaddlePaddle runtime:
<img
src=
"src/local_architecture.png"
/>
### Training Data
In PaddlePaddle v0.10.0, training data is typically read
with
[
data reader
](
../reader/README.md
)
from Python. This approach is
no longer efficient when training distributedly since the Python
process no longer runs on the same node with the trainer processes,
the Python reader will need to read from the distributed filesystem
(assuming it has the access) and send to the trainers, doubling the
network traffic.
When doing distributed training, the user can still use Python data
reader: the training data are sent with
`session.eval`
. However should
be used for debugging purpose only. The users are encouraged to use
the read data OPs.
## References:
[
1] [TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems
](
https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf
)
[
2] [TensorFlow: A System for Large-Scale Machine Learning
](
https://www.usenix.org/system/files/conference/osdi16/osdi16-abadi.pdf
)
doc/design/
ops/dist_train
.md
→
doc/design/
refactor/parameter_server
.md
浏览文件 @
62d597c1
文件已移动
doc/design/refactor/src/compiler.graffle
0 → 100644
浏览文件 @
62d597c1
文件已添加
doc/design/refactor/src/compiler.png
0 → 100644
浏览文件 @
62d597c1
15.5 KB
doc/design/
ops
/src/dist-graph.graffle
→
doc/design/
refactor
/src/dist-graph.graffle
浏览文件 @
62d597c1
文件已移动
doc/design/
ops
/src/dist-graph.png
→
doc/design/
refactor
/src/dist-graph.png
浏览文件 @
62d597c1
文件已移动
doc/design/refactor/src/distributed_architecture.graffle
0 → 100644
浏览文件 @
62d597c1
文件已添加
doc/design/refactor/src/distributed_architecture.png
0 → 100644
浏览文件 @
62d597c1
46.5 KB
doc/design/
ops
/src/local-graph.graffle
→
doc/design/
refactor
/src/local-graph.graffle
浏览文件 @
62d597c1
文件已移动
doc/design/
ops
/src/local-graph.png
→
doc/design/
refactor
/src/local-graph.png
浏览文件 @
62d597c1
文件已移动
doc/design/refactor/src/local_architecture.graffle
0 → 100644
浏览文件 @
62d597c1
文件已添加
doc/design/refactor/src/local_architecture.png
0 → 100644
浏览文件 @
62d597c1
28.3 KB
doc/design/refactor/src/paddle-compile.graffle
0 → 100644
浏览文件 @
62d597c1
文件已添加
doc/design/refactor/src/paddle-compile.png
0 → 100644
浏览文件 @
62d597c1
19.7 KB
doc/faq/index_cn.rst
浏览文件 @
62d597c1
...
@@ -247,11 +247,11 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字
...
@@ -247,11 +247,11 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字
CMake Warning at cmake/version.cmake:20 (message):
CMake Warning at cmake/version.cmake:20 (message):
Cannot add paddle version from git tag
Cannot add paddle version from git tag
那么用户需要拉取所有的远程分支到本机,命令为 :code:`git fetch upstream`,然后重新cmake即可。
那么用户需要拉取所有的远程分支到本机,命令为 :code:`git fetch upstream`,然后重新cmake即可。
12. A protocol message was rejected because it was too big
12. A protocol message was rejected because it was too big
----------------------------------------------------------
----------------------------------------------------------
--
如果在训练NLP相关模型时,出现以下错误:
如果在训练NLP相关模型时,出现以下错误:
...
@@ -316,10 +316,42 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
...
@@ -316,10 +316,42 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异
* 模型一直不收敛,发散到了一个数值特别大的地方。
* 模型一直不收敛,发散到了一个数值特别大的地方。
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习率或者对数据进行归一化处理。
这里有两种有效的解决方法:
1. 设置 :code:`gradient_clipping_threshold` 参数,示例代码如下:
.. code-block:: python
optimizer = paddle.optimizer.RMSProp(
learning_rate=1e-3,
gradient_clipping_threshold=10.0,
regularization=paddle.optimizer.L2Regularization(rate=8e-4))
具体可以参考 `nmt_without_attention <https://github.com/PaddlePaddle/models/blob/develop/nmt_without_attention/train.py#L35>`_ 示例。
2. 设置 :code:`error_clipping_threshold` 参数,示例代码如下:
.. code-block:: python
decoder_inputs = paddle.layer.fc(
act=paddle.activation.Linear(),
size=decoder_size * 3,
bias_attr=False,
input=[context, current_word],
layer_attr=paddle.attr.ExtraLayerAttribute(
error_clipping_threshold=100.0))
完整代码可以参考示例 `machine translation <https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/train.py#L66>`_ 。
两种方法的区别:
1. 两者都是对梯度的截断,但截断时机不同,前者在 :code:`optimzier` 更新网络参数时应用;后者在激活函数反向计算时被调用;
2. 截断对象不同:前者截断可学习参数的梯度,后者截断回传给前层的梯度;
除此之外,还可以通过减小学习律或者对数据进行归一化处理来解决这类问题。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
------------------------------------------------------------------------
------------------
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载:
pip uninstall py_paddle paddle
pip uninstall py_paddle paddle
...
@@ -329,7 +361,7 @@ pip uninstall py_paddle paddle
...
@@ -329,7 +361,7 @@ pip uninstall py_paddle paddle
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl
16. PaddlePaddle存储的参数格式是什么,如何和明文进行相互转化
16. PaddlePaddle存储的参数格式是什么,如何和明文进行相互转化
---------------------------------------------------------
---------------------------------------------------------
------------
PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数两部分组成。头信息中,1~4字节表示PaddlePaddle版本信息,请直接填充0;5~8字节表示每个参数占用的字节数,当保存的网络参数为float类型时为4,double类型时为8;9~16字节表示保存的参数总个数。
PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数两部分组成。头信息中,1~4字节表示PaddlePaddle版本信息,请直接填充0;5~8字节表示每个参数占用的字节数,当保存的网络参数为float类型时为4,double类型时为8;9~16字节表示保存的参数总个数。
...
@@ -381,7 +413,7 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
...
@@ -381,7 +413,7 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
parameters.set('emb', load_parameter(emb_param_file, 30000, 256))
parameters.set('emb', load_parameter(emb_param_file, 30000, 256))
18. 集群多节点训练,日志中保存均为网络通信类错误
18. 集群多节点训练,日志中保存均为网络通信类错误
------------------------------
------------------------------
-----------------------------
集群多节点训练,日志报错为网络通信类错误,比如 :code:`Connection reset by peer` 等。
集群多节点训练,日志报错为网络通信类错误,比如 :code:`Connection reset by peer` 等。
此类报错通常是由于某一个节点的错误导致这个节点的训练进程退出,从而引发其他节点无法连接导致,可以参考下面的步骤排查:
此类报错通常是由于某一个节点的错误导致这个节点的训练进程退出,从而引发其他节点无法连接导致,可以参考下面的步骤排查:
...
@@ -392,8 +424,8 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
...
@@ -392,8 +424,8 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。
* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。
19.
PaddlePaddle如何输出多个层
19.
如何调用 infer 接口输出多个layer的预测结果
------------------------------
------------------------------
-----------------------------
* 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下:
* 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下:
...
@@ -405,9 +437,28 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
...
@@ -405,9 +437,28 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
.. code-block:: python
.. code-block:: python
out = inferer.infer(input=data_batch, flatten_result=False, field=["value"])
out = inferer.infer(input=data_batch, field=["value"])
需要注意的是:
这里设置 :code:`flatten_result=False`,得到的输出结果是元素个数等于输出字段数的 :code:`list`,该 :code:`list` 的每个元素是由所有输出层相应字段结果组成的 :code:`list`,每个字段结果的类型是 :code:`numpy.array`。:code:`flatten_result` 的默认值为 :code:`True`,该情况下,PaddlePaddle会分别对每个字段将所有输出层的结果按行进行拼接,如果各输出层该字段 :code:`numpy.array` 结果的相应维数不匹配,程序将不能正常运行。
* 如果指定了2个layer作为输出层,实际上需要的输出结果是两个矩阵;
* 假设第一个layer的输出A是一个 N1 * M1 的矩阵,第二个 Layer 的输出B是一个 N2 * M2 的矩阵;
* paddle.v2 默认会将A和B 横向拼接,当N1 和 N2 大小不一样时,会报如下的错误:
.. code-block:: python
ValueError: all the input array dimensions except for the concatenation axis must match exactly
多个层的输出矩阵的高度不一致导致拼接失败,这种情况常常发生在:
* 同时输出序列层和非序列层;
* 多个输出层处理多个不同长度的序列;
此时可以在调用infer接口时通过设置 :code:`flatten_result=False` , 跳过“拼接”步骤,来解决上面的问题。这时,infer接口的返回值是一个python list:
* list 中元素的个数等于网络中输出层的个数;
* list 中每个元素是一个layer的输出结果矩阵,类型是numpy的ndarray;
* 每一个layer输出矩阵的高度,在非序列输入时:等于样本数;序列输入时等于:输入序列中元素的总数;宽度等于配置中layer的size;
20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用
20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用
-------------------------------------------------------------
-------------------------------------------------------------
...
@@ -416,8 +467,8 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
...
@@ -416,8 +467,8 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
* PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。
* PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。
21.
dropout 使用
21.
两种使用 drop_out 的方法有何区别?
-----------------
-----------------
------------------------------------
* 在PaddlePaddle中使用dropout有两种方式
* 在PaddlePaddle中使用dropout有两种方式
...
@@ -503,7 +554,7 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu
...
@@ -503,7 +554,7 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu
optimizer = paddle.optimizer.Adam(
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate=1e-3,
learning_rate_schedule="manual",
learning_rate_schedule="manual",
learning_rate_args="1:1.0,2:0.9,3:0.8",)
learning_rate_args="1:1.0,2:0.9,3:0.8",)
在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。
在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。
...
@@ -512,3 +563,30 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu
...
@@ -512,3 +563,30 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu
出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。
出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。
24. PaddlePaddle 中不同的 recurrent layer 的区别
--------------------------------------------------
以LSTM为例,在PaddlePaddle中包含以下 recurrent layer:
* :code:`paddle.layer.lstmemory`
* :code:`paddle.networks.simple_lstm`
* :code:`paddle.networks.lstmemory_group`
* :code:`paddle.networks.bidirectional_lstm`
按照具体实现方式可以归纳为2类:
1. 由 recurrent_group 实现的 recurrent layer:
* 用户在使用这一类recurrent layer时,可以访问由recurrent unit在一个时间步内计算得到的中间值(例如:hidden states, memory cells等);
* 上述的 :code:`paddle.networks.lstmemory_group` 是这一类的 recurrent layer ;
2. 将recurrent layer作为一个整体来实现:
* 用户在使用这一类recurrent layer,只能访问它们的输出值;
* 上述的 :code:`paddle.networks.lstmemory_group` 、 :code:`paddle.networks.simple_lstm` 和 :code:`paddle.networks.bidirectional_lstm` 属于这一类的实现;
将recurrent layer作为一个整体来实现, 能够针对CPU和GPU的计算做更多优化, 所以相比于recurrent group的实现方式, 第二类 recurrent layer 计算效率更高。 在实际应用中,如果用户不需要访问LSTM的中间变量,而只需要获得recurrent layer计算的输出,我们建议使用第二类实现。
此外,关于LSTM, PaddlePaddle中还包含 :code:`paddle.networks.lstmemory_unit` 这一计算单元:
* 不同于上述介绍的recurrent layer , :code:`paddle.networks.lstmemory_unit` 定义了LSTM单元在一个时间步内的计算过程,它并不是一个完整的recurrent layer,也不能接收序列数据作为输入;
* :code:`paddle.networks.lstmemory_unit` 只能在recurrent_group中作为step function使用;
doc/getstarted/build_and_install/docker_install_cn.rst
浏览文件 @
62d597c1
...
@@ -20,7 +20,7 @@ Docker使用入门
...
@@ -20,7 +20,7 @@ Docker使用入门
docker pull paddlepaddle/paddle:0.10.0
docker pull paddlepaddle/paddle:0.10.0
来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。
来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用
d
ocker.paddlepaddle.org/paddle下载。
- *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。
- *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。
实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。
实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。
...
...
paddle/gserver/activations/MKLDNNActivation.cpp
浏览文件 @
62d597c1
...
@@ -27,31 +27,53 @@ static ClassRegistrar<ActivationFunction> gMKLDNNActivationRegistrar;
...
@@ -27,31 +27,53 @@ static ClassRegistrar<ActivationFunction> gMKLDNNActivationRegistrar;
#define MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) mkldnn_##ACT_TYPE##Activation
#define MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) mkldnn_##ACT_TYPE##Activation
/**
/**
* @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
* @def BEGIN_MKLDNN_ACTIVATION
*/
#define BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) : public BASE_CLASS {
/**
* @def END_MKLDNN_ACTIVATION
*/
*/
#define DEFINE_MKLDNN_ELTWISE_ACTIVATION(ACT_TYPE, ALPHA, BWD_ALPHA) \
#define END_MKLDNN_ACTIVATION(ACT_TYPE) \
class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) \
private: \
: public MKLDNNEltwiseActivation { \
static const std::string name; \
private: \
\
static const std::string name; \
public: \
static const float alpha; \
const std::string& getName() const { return name; } \
static const float bwdAlpha; \
} \
\
; \
public: \
const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \
const std::string& getName() const { return name; } \
"mkldnn_" #ACT_TYPE; \
float getAlpha() const { return alpha; } \
static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \
float getBwdAlpha() const { return bwdAlpha; } \
gMKLDNNActivationRegistrar \
}; \
.registerClass<MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)>( \
const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \
"mkldnn_" #ACT_TYPE); \
"mkldnn_" #ACT_TYPE; \
const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \
const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA; \
static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \
gMKLDNNActivationRegistrar \
.registerClass<MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)>( \
"mkldnn_" #ACT_TYPE); \
});
});
/**
* @def DEFINE_MKLDNN_ACTIVATION
*/
#define DEFINE_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
END_MKLDNN_ACTIVATION(ACT_TYPE)
/**
* @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
*/
#define DEFINE_MKLDNN_ELTWISE_ACTIVATION( \
ACT_TYPE, BASE_CLASS, ALPHA, BWD_ALPHA) \
BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
private: \
static const float alpha; \
static const float bwdAlpha; \
\
public: \
float getAlpha() const { return alpha; } \
float getBwdAlpha() const { return bwdAlpha; } \
END_MKLDNN_ACTIVATION(ACT_TYPE) \
const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \
const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA;
/**
/**
* @brief MKLDNN Relu Activation.
* @brief MKLDNN Relu Activation.
* Actually mkldnn_relu is Leaky Relu.
* Actually mkldnn_relu is Leaky Relu.
...
@@ -59,19 +81,129 @@ static ClassRegistrar<ActivationFunction> gMKLDNNActivationRegistrar;
...
@@ -59,19 +81,129 @@ static ClassRegistrar<ActivationFunction> gMKLDNNActivationRegistrar;
* f(x) = negative_slope * x (x < 0)
* f(x) = negative_slope * x (x < 0)
* @note the negative_slope should be -0.f in forward
* @note the negative_slope should be -0.f in forward
*/
*/
DEFINE_MKLDNN_ELTWISE_ACTIVATION
(
relu
,
-
0.
f
,
0.
f
)
DEFINE_MKLDNN_ELTWISE_ACTIVATION
(
relu
,
MKLDNNEltwiseActivation
,
-
0.
f
,
0.
f
)
/**
/**
* @brief MKLDNN Tanh Activation.
* @brief MKLDNN Tanh Activation.
*/
*/
DEFINE_MKLDNN_ELTWISE_ACTIVATION
(
tanh
,
0.
f
,
0.
f
)
DEFINE_MKLDNN_ELTWISE_ACTIVATION
(
tanh
,
MKLDNNEltwiseActivation
,
0.
f
,
0.
f
)
/**
/**
* @brief MKLDNN ELU(Exponential Linear Unit) Activation.
* @brief MKLDNN ELU(Exponential Linear Unit) Activation.
* f(x) = x (x >= 0)
* f(x) = x (x >= 0)
* f(x) = negative_slope * (exp(x) - 1) (x < 0)
* f(x) = negative_slope * (exp(x) - 1) (x < 0)
*/
*/
DEFINE_MKLDNN_ELTWISE_ACTIVATION
(
elu
,
0.
f
,
0.
f
)
DEFINE_MKLDNN_ELTWISE_ACTIVATION
(
elu
,
MKLDNNEltwiseActivation
,
0.
f
,
0.
f
)
mkldnn
::
algorithm
MKLDNNEltwiseActivation
::
getAlgo
(
std
::
string
type
)
const
{
const
std
::
map
<
std
::
string
,
mkldnn
::
algorithm
>
algoMap
=
{
{
"relu"
,
algorithm
::
eltwise_relu
},
{
"tanh"
,
algorithm
::
eltwise_tanh
},
{
"elu"
,
algorithm
::
eltwise_elu
}};
type
.
erase
(
0
,
7
);
// remove mkldnn_
algorithm
algo
=
(
algorithm
)
0
;
mapGet
(
type
,
algoMap
,
&
algo
);
return
algo
;
}
void
MKLDNNEltwiseActivation
::
resetFwd
(
Argument
&
act
)
{
if
(
cnt_
==
act
.
value
->
getElementCnt
())
{
return
;
}
MKLDNNActivation
::
resetFwd
(
act
);
// note: alpha represents the NegativeSlope when used in relu.
float
alpha
=
getAlpha
();
float
beta
=
getBeta
();
algorithm
algo
=
getAlgo
(
this
->
getName
());
auto
fwdDesc
=
eltwise_fwd
::
desc
(
mkldnn
::
prop_kind
::
forward_training
,
algo
,
val_
->
getMemoryDesc
(),
alpha
,
beta
);
fwdPD_
.
reset
(
new
eltwise_fwd
::
primitive_desc
(
fwdDesc
,
*
engine_
));
// use inplace for forward but save input value before submit
inVal_
=
val_
;
copyInVal_
=
nullptr
;
if
(
act
.
grad
&&
algo
==
algorithm
::
eltwise_tanh
)
{
// tanh need save src input for backward
inVal_
=
MKLDNNMatrix
::
create
(
nullptr
,
val_
->
getPrimitiveDesc
());
copyInVal_
=
std
::
make_shared
<
mkldnn
::
reorder
>
(
*
val_
,
*
inVal_
);
CHECK
(
copyInVal_
)
<<
"should not be emptry"
;
pipelineFwd_
.
push_back
(
*
copyInVal_
);
}
fwd_
.
reset
(
new
eltwise_fwd
(
*
fwdPD_
,
*
val_
,
*
val_
));
pipelineFwd_
.
push_back
(
*
fwd_
);
needResetBwd_
=
true
;
}
void
MKLDNNEltwiseActivation
::
resetBwd
(
Argument
&
act
)
{
if
(
!
needResetBwd_
)
{
return
;
}
VLOG
(
MKLDNN_BASE
)
<<
getName
()
<<
" reset mkldnn backward"
;
needResetBwd_
=
false
;
algorithm
algo
=
getAlgo
(
this
->
getName
());
float
alpha
=
getBwdAlpha
();
float
beta
=
getBeta
();
grad_
=
MKLDNNMatrix
::
create
(
act
.
grad
,
val_
->
getPrimitiveDesc
());
auto
eng
=
CPUEngine
::
Instance
().
getEngine
();
auto
bwdDesc
=
eltwise_bwd
::
desc
(
algo
,
grad_
->
getMemoryDesc
(),
val_
->
getMemoryDesc
(),
alpha
,
beta
);
auto
bwdPD
=
eltwise_bwd
::
primitive_desc
(
bwdDesc
,
eng
,
*
fwdPD_
);
CHECK
(
inVal_
);
bwd_
.
reset
(
new
eltwise_bwd
(
bwdPD
,
*
inVal_
,
*
grad_
,
*
grad_
));
pipelineBwd_
.
clear
();
pipelineBwd_
.
push_back
(
*
bwd_
);
}
/**
* @brief MKLDNN Softmax Activation
*/
DEFINE_MKLDNN_ACTIVATION
(
softmax
,
MKLDNNSoftmaxActivation
)
void
MKLDNNSoftmaxActivation
::
resetFwd
(
Argument
&
act
)
{
if
(
cnt_
==
act
.
value
->
getElementCnt
())
{
return
;
}
MKLDNNActivation
::
resetFwd
(
act
);
int
axis
=
1
;
auto
fwdDesc
=
softmax_fwd
::
desc
(
mkldnn
::
prop_kind
::
forward_scoring
,
val_
->
getMemoryDesc
(),
axis
);
auto
fwdPD
=
softmax_fwd
::
primitive_desc
(
fwdDesc
,
*
engine_
);
fwd_
.
reset
(
new
softmax_fwd
(
fwdPD
,
*
val_
,
*
val_
));
pipelineFwd_
.
push_back
(
*
fwd_
);
}
Error
__must_check
MKLDNNSoftmaxActivation
::
forward
(
Argument
&
act
)
{
resetFwd
(
act
);
stream_
->
submit
(
pipelineFwd_
);
real
*
v
=
act
.
value
->
getData
();
real
threshold
=
exp
(
-
64
);
#pragma omp parallel for
for
(
size_t
i
=
0
;
i
<
act
.
value
->
getElementCnt
();
++
i
)
{
v
[
i
]
=
v
[
i
]
<
threshold
?
threshold
:
v
[
i
];
}
return
Error
();
}
Error
__must_check
MKLDNNSoftmaxActivation
::
backward
(
Argument
&
act
)
{
MatrixPtr
outputV
=
act
.
value
;
MatrixPtr
outputG
=
act
.
grad
;
Matrix
::
resizeOrCreate
(
sftMaxDot_
,
outputG
->
getHeight
(),
outputG
->
getWidth
(),
/* trans */
false
,
/* useGpu */
false
);
Matrix
::
resizeOrCreate
(
sftMaxSum_
,
outputG
->
getHeight
(),
1
,
/* trans */
false
,
/* useGpu */
false
);
sftMaxDot_
->
dotMul
(
*
outputG
,
*
outputV
);
sftMaxSum_
->
colMerge
(
*
sftMaxDot_
);
act
.
grad
->
softmaxDerivative
(
*
act
.
value
,
*
sftMaxSum_
);
return
Error
();
}
ActivationFunction
*
MKLDNNActivation
::
create
(
const
std
::
string
&
type
)
{
ActivationFunction
*
MKLDNNActivation
::
create
(
const
std
::
string
&
type
)
{
return
gMKLDNNActivationRegistrar
.
createByType
(
type
);
return
gMKLDNNActivationRegistrar
.
createByType
(
type
);
...
@@ -84,4 +216,34 @@ std::vector<std::string> MKLDNNActivation::getAllRegisteredTypes() {
...
@@ -84,4 +216,34 @@ std::vector<std::string> MKLDNNActivation::getAllRegisteredTypes() {
return
types
;
return
types
;
}
}
void
MKLDNNActivation
::
resetFwd
(
Argument
&
act
)
{
VLOG
(
MKLDNN_BASE
)
<<
getName
()
<<
" reset mkldnn forward"
;
cnt_
=
act
.
value
->
getElementCnt
();
pipelineFwd_
.
clear
();
stream_
.
reset
(
new
MKLDNNStream
());
engine_
.
reset
(
new
mkldnn
::
engine
(
mkldnn
::
engine
::
cpu
,
0
));
val_
=
std
::
dynamic_pointer_cast
<
MKLDNNMatrix
>
(
act
.
value
);
if
(
val_
==
nullptr
)
{
int
bs
=
act
.
getBatchSize
();
int
ih
=
act
.
getFrameHeight
()
>
0
?
act
.
getFrameHeight
()
:
1
;
int
iw
=
act
.
getFrameWidth
()
>
0
?
act
.
getFrameWidth
()
:
1
;
int
ic
=
cnt_
/
bs
/
ih
/
iw
;
CHECK_EQ
(
cnt_
,
(
size_t
)
bs
*
ic
*
ih
*
iw
);
val_
=
MKLDNNMatrix
::
create
(
act
.
value
,
{
bs
,
ic
,
ih
,
iw
},
mkldnn
::
memory
::
format
::
nchw
,
*
engine_
);
CHECK
(
val_
);
val_
->
downSpatial
();
}
}
Error
__must_check
MKLDNNActivation
::
forward
(
Argument
&
act
)
{
resetFwd
(
act
);
stream_
->
submit
(
pipelineFwd_
);
return
Error
();
}
Error
__must_check
MKLDNNActivation
::
backward
(
Argument
&
act
)
{
resetBwd
(
act
);
stream_
->
submit
(
pipelineBwd_
);
return
Error
();
}
}
// namespace paddle
}
// namespace paddle
paddle/gserver/activations/MKLDNNActivation.h
浏览文件 @
62d597c1
...
@@ -36,6 +36,7 @@ protected:
...
@@ -36,6 +36,7 @@ protected:
// mkldnn matrix, primitive, stream and pipeline
// mkldnn matrix, primitive, stream and pipeline
MKLDNNMatrixPtr
val_
;
MKLDNNMatrixPtr
val_
;
MKLDNNMatrixPtr
grad_
;
MKLDNNMatrixPtr
grad_
;
std
::
shared_ptr
<
mkldnn
::
engine
>
engine_
;
std
::
shared_ptr
<
MKLDNNStream
>
stream_
;
std
::
shared_ptr
<
MKLDNNStream
>
stream_
;
std
::
shared_ptr
<
mkldnn
::
primitive
>
fwd_
;
std
::
shared_ptr
<
mkldnn
::
primitive
>
fwd_
;
std
::
shared_ptr
<
mkldnn
::
primitive
>
bwd_
;
std
::
shared_ptr
<
mkldnn
::
primitive
>
bwd_
;
...
@@ -48,8 +49,18 @@ public:
...
@@ -48,8 +49,18 @@ public:
static
ActivationFunction
*
create
(
const
std
::
string
&
type
);
static
ActivationFunction
*
create
(
const
std
::
string
&
type
);
static
std
::
vector
<
std
::
string
>
getAllRegisteredTypes
();
static
std
::
vector
<
std
::
string
>
getAllRegisteredTypes
();
virtual
const
std
::
string
&
getName
()
const
=
0
;
virtual
const
std
::
string
&
getName
()
const
=
0
;
virtual
Error
__must_check
forward
(
Argument
&
act
)
=
0
;
/**
virtual
Error
__must_check
backward
(
Argument
&
act
)
=
0
;
* reset the forward primitives
*/
virtual
void
resetFwd
(
Argument
&
act
);
/**
* reset the backward primitives,
* can not merge this functions into resetFwd as the grad data
* would be changing before backward.
*/
virtual
void
resetBwd
(
Argument
&
act
)
{}
virtual
Error
__must_check
forward
(
Argument
&
act
);
virtual
Error
__must_check
backward
(
Argument
&
act
);
};
};
/**
/**
...
@@ -59,6 +70,7 @@ public:
...
@@ -59,6 +70,7 @@ public:
class
MKLDNNEltwiseActivation
:
public
MKLDNNActivation
{
class
MKLDNNEltwiseActivation
:
public
MKLDNNActivation
{
typedef
mkldnn
::
eltwise_forward
eltwise_fwd
;
typedef
mkldnn
::
eltwise_forward
eltwise_fwd
;
typedef
mkldnn
::
eltwise_backward
eltwise_bwd
;
typedef
mkldnn
::
eltwise_backward
eltwise_bwd
;
typedef
mkldnn
::
algorithm
algorithm
;
protected:
protected:
// save the forward primitive desc, which can be used backward
// save the forward primitive desc, which can be used backward
...
@@ -70,9 +82,7 @@ protected:
...
@@ -70,9 +82,7 @@ protected:
public:
public:
MKLDNNEltwiseActivation
()
{}
MKLDNNEltwiseActivation
()
{}
~
MKLDNNEltwiseActivation
()
{}
~
MKLDNNEltwiseActivation
()
{}
virtual
const
std
::
string
&
getName
()
const
=
0
;
virtual
const
std
::
string
&
getName
()
const
=
0
;
// in common, the alpha of forward and backward should be equal.
// in common, the alpha of forward and backward should be equal.
...
@@ -80,105 +90,30 @@ public:
...
@@ -80,105 +90,30 @@ public:
virtual
float
getAlpha
()
const
=
0
;
virtual
float
getAlpha
()
const
=
0
;
virtual
float
getBwdAlpha
()
const
=
0
;
virtual
float
getBwdAlpha
()
const
=
0
;
virtual
float
getBeta
()
const
{
return
0.
f
;
}
virtual
float
getBeta
()
const
{
return
0.
f
;
}
virtual
mkldnn
::
algorithm
getAlgo
(
const
std
::
string
&
type
)
const
{
virtual
algorithm
getAlgo
(
std
::
string
type
)
const
;
if
(
type
==
"mkldnn_relu"
)
{
void
resetFwd
(
Argument
&
act
)
override
;
return
mkldnn
::
algorithm
::
eltwise_relu
;
void
resetBwd
(
Argument
&
act
)
override
;
}
else
if
(
type
==
"mkldnn_tanh"
)
{
};
return
mkldnn
::
algorithm
::
eltwise_tanh
;
}
else
if
(
type
==
"mkldnn_elu"
)
{
return
mkldnn
::
algorithm
::
eltwise_elu
;
}
else
{
LOG
(
FATAL
)
<<
"Unkown eltwise activation type: "
<<
type
;
}
return
(
mkldnn
::
algorithm
)
0
;
}
/**
* reshape and reset the forward primitives
*/
void
resetFwd
(
Argument
&
act
)
{
if
(
cnt_
==
act
.
value
->
getElementCnt
())
{
return
;
}
VLOG
(
MKLDNN_BASE
)
<<
getName
()
<<
" reset mkldnn forward"
;
cnt_
=
act
.
value
->
getElementCnt
();
stream_
.
reset
(
new
MKLDNNStream
());
auto
eng
=
CPUEngine
::
Instance
().
getEngine
();
// get algo setting
mkldnn
::
algorithm
algo
=
getAlgo
(
this
->
getName
());
// note: alpha represents the NegativeSlope when used in relu.
float
alpha
=
getAlpha
();
float
beta
=
getBeta
();
pipelineFwd_
.
clear
();
val_
=
std
::
dynamic_pointer_cast
<
MKLDNNMatrix
>
(
act
.
value
);
if
(
val_
==
nullptr
)
{
int
bs
=
act
.
getBatchSize
();
int
ih
=
act
.
getFrameHeight
()
>
0
?
act
.
getFrameHeight
()
:
1
;
int
iw
=
act
.
getFrameWidth
()
>
0
?
act
.
getFrameWidth
()
:
1
;
int
ic
=
cnt_
/
bs
/
ih
/
iw
;
CHECK_EQ
(
cnt_
,
(
size_t
)
bs
*
ic
*
ih
*
iw
);
val_
=
MKLDNNMatrix
::
create
(
act
.
value
,
{
bs
,
ic
,
ih
,
iw
},
mkldnn
::
memory
::
format
::
nchw
,
eng
);
CHECK
(
val_
);
}
auto
fwdDesc
=
eltwise_fwd
::
desc
(
mkldnn
::
prop_kind
::
forward_training
,
algo
,
val_
->
getMemoryDesc
(),
alpha
,
beta
);
fwdPD_
.
reset
(
new
eltwise_fwd
::
primitive_desc
(
fwdDesc
,
eng
));
// use inplace for forward but save input value before submit
inVal_
=
val_
;
copyInVal_
=
nullptr
;
if
(
act
.
grad
&&
algo
==
mkldnn
::
algorithm
::
eltwise_tanh
)
{
// tanh need save src input for backward
inVal_
=
MKLDNNMatrix
::
create
(
nullptr
,
val_
->
getPrimitiveDesc
());
copyInVal_
=
std
::
make_shared
<
mkldnn
::
reorder
>
(
*
val_
,
*
inVal_
);
CHECK
(
copyInVal_
)
<<
"should not be emptry"
;
pipelineFwd_
.
push_back
(
*
copyInVal_
);
}
fwd_
.
reset
(
new
eltwise_fwd
(
*
fwdPD_
,
*
val_
,
*
val_
));
pipelineFwd_
.
push_back
(
*
fwd_
);
needResetBwd_
=
true
;
}
/**
/**
* reset the backward primitives, can not merge into resetFwd as the grad data
* @brief Base class of MKLDNN softmax Activation,
* would be changing before backward.
* only have mkldnn forward, use cpu implement for backward.
*/
*/
void
resetBwd
(
Argument
&
act
)
{
class
MKLDNNSoftmaxActivation
:
public
MKLDNNActivation
{
if
(
!
needResetBwd_
)
{
typedef
mkldnn
::
softmax_forward
softmax_fwd
;
return
;
}
VLOG
(
MKLDNN_BASE
)
<<
getName
()
<<
" reset mkldnn backward"
;
needResetBwd_
=
false
;
mkldnn
::
algorithm
algo
=
getAlgo
(
this
->
getName
());
float
alpha
=
getBwdAlpha
();
float
beta
=
getBeta
();
grad_
=
MKLDNNMatrix
::
create
(
act
.
grad
,
val_
->
getPrimitiveDesc
());
auto
eng
=
CPUEngine
::
Instance
().
getEngine
();
auto
bwdDesc
=
eltwise_bwd
::
desc
(
algo
,
grad_
->
getMemoryDesc
(),
val_
->
getMemoryDesc
(),
alpha
,
beta
);
auto
bwdPD
=
eltwise_bwd
::
primitive_desc
(
bwdDesc
,
eng
,
*
fwdPD_
);
CHECK
(
inVal_
);
bwd_
.
reset
(
new
eltwise_bwd
(
bwdPD
,
*
inVal_
,
*
grad_
,
*
grad_
));
pipelineBwd_
.
clear
();
pipelineBwd_
.
push_back
(
*
bwd_
);
}
Error
__must_check
forward
(
Argument
&
act
)
{
private:
resetFwd
(
act
);
// for backward
stream_
->
submit
(
pipelineFwd_
);
MatrixPtr
sftMaxSum_
;
return
Error
();
MatrixPtr
sftMaxDot_
;
}
Error
__must_check
backward
(
Argument
&
act
)
{
public:
resetBwd
(
act
);
MKLDNNSoftmaxActivation
()
{}
stream_
->
submit
(
pipelineBwd_
);
~
MKLDNNSoftmaxActivation
()
{}
return
Error
();
virtual
const
std
::
string
&
getName
()
const
=
0
;
}
void
resetFwd
(
Argument
&
act
)
override
;
Error
__must_check
forward
(
Argument
&
act
)
override
;
Error
__must_check
backward
(
Argument
&
act
)
override
;
};
};
}
// namespace paddle
}
// namespace paddle
paddle/gserver/tests/test_MKLDNN.cpp
浏览文件 @
62d597c1
...
@@ -222,8 +222,8 @@ static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) {
...
@@ -222,8 +222,8 @@ static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) {
}
}
void
testActivation
(
std
::
string
&
actType
,
const
testActDesc
&
pm
)
{
void
testActivation
(
std
::
string
&
actType
,
const
testActDesc
&
pm
)
{
// TODO(TJ):
mkldnn_softmax not implemented, paddle do not have
elu activation
// TODO(TJ):
remove me when paddle support
elu activation
if
(
actType
==
"mkldnn_
softmax"
||
actType
==
"mkldnn_
elu"
)
{
if
(
actType
==
"mkldnn_elu"
)
{
return
;
return
;
}
}
const
std
::
string
compareTypes
[]
=
{
actType
,
actType
.
erase
(
0
,
7
)};
const
std
::
string
compareTypes
[]
=
{
actType
,
actType
.
erase
(
0
,
7
)};
...
...
paddle/operators/accuracy_op.cu
浏览文件 @
62d597c1
...
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
...
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return
;
return
;
}
}
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
1
,
PADDLE_CUDA_NUM_THREADS
>>>
(
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
num_samples
,
infer_width
,
inference_data
,
label_data
,
accuracy_data
);
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
num_samples
,
infer_width
,
inference_data
,
label_data
,
accuracy_data
);
}
}
};
};
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
62d597c1
...
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
must not be
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
should be not
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Label) must not be null."
);
"Input(Label) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Y"
),
"Output(Y) must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Y"
),
"Output(Y) should be not null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank
must
be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank
should
be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
"Input(Label)'s rank
must
be 2."
);
"Input(Label)'s rank
should
be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label)
must
"
"The 1st dimension of Input(X) and Input(Label)
should
"
"be equal."
);
"be equal."
);
if
(
ctx
.
Attr
<
bool
>
(
"soft
_l
abel"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"soft
L
abel"
))
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"If Attr(soft
_label) == true, T
he 2nd dimension of "
"If Attr(soft
Label) == true, t
he 2nd dimension of "
"Input(X) and Input(Label)
must
be equal."
);
"Input(X) and Input(Label)
should
be equal."
);
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"If Attr(soft
_label) == false, T
he 2nd dimension of "
"If Attr(soft
Label) == false, t
he 2nd dimension of "
"Input(Label)
must
be 1."
);
"Input(Label)
should
be 1."
);
}
}
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
x
->
dims
()[
0
],
1
});
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
x
->
dims
()[
0
],
1
});
...
@@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
must not be
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
should be not
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Label)
must not be
null."
);
"Input(Label)
should be not
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) must not be null."
);
"Input(Y@GRAD) shoudl be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
framework
::
GradVarName
(
"X"
)),
"Output(X@GRAD) should be not null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank should be 2."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
().
size
(),
2
,
"Input(Y@Grad)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
().
size
(),
2
,
"Input(Y@Grad)'s rank should be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
"Input(Label)'s rank
must
be 2."
);
"Input(Label)'s rank
should
be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label)
must
"
"The 1st dimension of Input(X) and Input(Label)
should
"
"be equal."
);
"be equal."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
dy
->
dims
()[
0
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
dy
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Y@Grad)
must
"
"The 1st dimension of Input(X) and Input(Y@Grad)
should
"
"be equal."
);
"be equal."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
dy
->
dims
()[
1
],
1
,
"The 2nd dimension of Input(Y@Grad)
must
be 1."
);
"The 2nd dimension of Input(Y@Grad)
should
be 1."
);
if
(
ctx
.
Attr
<
bool
>
(
"soft
_l
abel"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"soft
L
abel"
))
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"
If Attr(soft_label) == true, T
he 2nd dimension of "
"
When Attr(softLabel) == true, t
he 2nd dimension of "
"Input(X) and Input(Label)
must
be equal."
);
"Input(X) and Input(Label)
should
be equal."
);
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"
If Attr(soft_label) == false, T
he 2nd dimension of "
"
When Attr(softLabel) == false, t
he 2nd dimension of "
"Input(Label)
must
be 1."
);
"Input(Label)
should
be 1."
);
}
}
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
@@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
CrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
CrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of CrossEntropyOp"
);
AddInput
(
"X"
,
AddInput
(
"Label"
,
"The second input of CrossEntropyOp"
);
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
AddOutput
(
"Y"
,
"The output of CrossEntropyOp"
);
"where N is the batch size and D is the number of classes. "
AddAttr
<
bool
>
(
"soft_label"
,
"Is soft label. Default zero."
)
"This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator."
);
AddInput
(
"Label"
,
"(Tensor, default Tensor<int>), the ground truth which is "
"a 2-D tensor. "
"When softLabel is set to false, `Label` is a Tensor<int> with shape "
"[N x 1]. "
"When softLabel is set to true, `Label` is a Tensor<float/double> "
"with shape [N x K]."
);
AddOutput
(
"Y"
,
"(Tensor, default Tensor<float>), a 2-D tensor "
"with shape [N x 1]. The cross entropy loss."
);
AddAttr
<
bool
>
(
"softLabel"
,
"(bool, default false), a flag to indicate whether to interpretate "
"the given labels as soft labels."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
CrossEntropy Operator.
CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
computation.
1) One-hot cross-entropy:
1) One-hot cross-entropy:
soft
_label = F
alse, Label[i, 0] indicates the class index for sample i:
soft
Label = f
alse, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy:
2) Soft-label cross-entropy:
soft
_label = T
rue, Label[i, j] indicates the soft label of class j
soft
Label = t
rue, Label[i, j] indicates the soft label of class j
for sample i:
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
...
...
paddle/operators/cross_entropy_op.cu
浏览文件 @
62d597c1
...
@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
...
@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
tolerable_value
(
log
(
X
[
i
*
D
+
label
[
i
]]));
Y
[
i
]
=
-
TolerableValue
<
T
>
()
(
log
(
X
[
i
*
D
+
label
[
i
]]));
}
}
}
}
template
<
typename
T
>
__device__
__forceinline__
T
sum_single_warp
(
T
val
)
{
val
+=
__shfl_down
(
val
,
16
);
val
+=
__shfl_down
(
val
,
8
);
val
+=
__shfl_down
(
val
,
4
);
val
+=
__shfl_down
(
val
,
2
);
val
+=
__shfl_down
(
val
,
1
);
return
val
;
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
const
int
class_num
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
int
tid
=
threadIdx
.
x
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
extern
__shared__
T
d_sum
[];
T
sum
=
static_cast
<
T
>
(
0
);
d_sum
[
tid
]
=
0
;
for
(
int
j
=
0
;
j
<
D
;
j
++
)
{
sum
+=
label
[
i
*
D
+
j
]
*
tolerable_value
(
log
(
X
[
i
*
D
+
j
]));
int
cur_idx
=
tid
;
}
int
next_idx
=
blockIdx
.
x
*
class_num
+
tid
;
Y
[
i
]
=
-
sum
;
while
(
cur_idx
<
class_num
)
{
d_sum
[
tid
]
+=
TolerableValue
<
T
>
()(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
blockDim
.
x
;
cur_idx
+=
blockDim
.
x
;
}
__syncthreads
();
for
(
unsigned
int
stride
=
blockDim
.
x
>>
1
;
stride
>=
32
;
stride
>>=
1
)
{
if
(
tid
<
stride
)
d_sum
[
tid
]
+=
d_sum
[
tid
+
stride
];
__syncthreads
();
}
}
T
val
=
d_sum
[
tid
];
val
=
sum_single_warp
<
T
>
(
val
);
if
(
tid
==
0
)
Y
[
blockIdx
.
x
]
=
-
val
;
}
}
// TODO(qingqing): make zero setting a
n
common function.
// TODO(qingqing): make zero setting a common function.
template
<
typename
T
>
template
<
typename
T
>
__global__
void
z
ero
(
T
*
X
,
const
int
N
)
{
__global__
void
Z
ero
(
T
*
X
,
const
int
N
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
X
[
i
]
=
0.0
;
X
[
i
]
=
0.0
;
...
@@ -71,13 +94,10 @@ template <typename T>
...
@@ -71,13 +94,10 @@ template <typename T>
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
const
int
D
)
{
// TOOD(qingqing): optimize for this kernel
int
ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
if
(
ids
<
N
*
D
)
{
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
row_ids
=
ids
/
D
;
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
dX
[
ids
]
=
-
label
[
ids
]
*
dY
[
row_ids
]
/
X
[
ids
];
int
idx
=
i
*
D
+
j
;
dX
[
idx
]
=
-
label
[
idx
]
*
dY
[
i
]
/
X
[
idx
];
}
}
}
}
}
...
@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
...
@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"
It must use GPUPla
ce."
);
"
This kernel only runs on GPU devi
ce."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y
"
);
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label
"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label
"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y
"
);
auto
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
int
n
=
x
->
dims
()[
0
];
int
batch_size
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
class_num
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
if
(
ctx
.
Attr
<
bool
>
(
"softLabel"
))
{
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
SoftCrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
int
block
=
class_num
>
512
?
512
:
pow
(
2
,
int
(
std
::
log2
(
class_num
)));
d
);
SoftCrossEntropyKernel
<
T
><<<
batch_size
,
block
,
block
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
y_data
,
x_data
,
label_data
,
class_num
);
}
else
{
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
int
block
=
512
;
int
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
y_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
}
}
}
};
};
...
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
...
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
"This kernel only runs on GPU device."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
dy_data
=
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
)
);
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
))
->
data
<
T
>
(
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
(
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
const
T
*
x_data
=
x
->
data
<
T
>
(
);
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
batch_size
=
x
->
dims
()[
0
];
auto
*
dy_data
=
dy
->
data
<
T
>
();
int
class_num
=
x
->
dims
()[
1
];
auto
*
x_data
=
x
->
data
<
T
>
();
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
block
=
512
;
int
block
=
512
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
int
grid
=
(
batch_size
*
class_num
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dx_data
,
n
*
d
);
grid
=
(
n
+
block
-
1
)
/
block
;
if
(
ctx
.
Attr
<
bool
>
(
"softLabel"
))
{
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
*
label_data
=
label
->
data
<
T
>
();
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
SoftCrossEntropyGradientKernel
<
T
><<<
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
else
{
}
else
{
Zero
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
batch_size
*
class_num
);
auto
*
label_data
=
label
->
data
<
int
>
();
auto
*
label_data
=
label
->
data
<
int
>
();
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
grid
=
(
batch_size
+
block
-
1
)
/
block
;
label_data
,
n
,
d
);
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
}
}
}
};
};
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
62d597c1
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/hostdevice.h"
...
@@ -20,53 +21,51 @@ namespace paddle {
...
@@ -20,53 +21,51 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
template
<
typename
T
>
HOSTDEVICE
T
tolerable_value
(
const
T
x
)
{
struct
TolerableValue
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
HOSTDEVICE
T
operator
()(
const
T
&
x
)
const
{
const
T
kApproInf
=
1e20
;
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
if
(
x
==
INFINITY
)
{
const
T
kApproInf
=
1e20
;
return
kApproInf
;
if
(
x
==
INFINITY
)
return
kApproInf
;
if
(
x
==
-
INFINITY
)
return
-
kApproInf
;
return
x
;
}
}
if
(
x
==
-
INFINITY
)
{
};
return
-
kApproInf
;
}
return
x
;
}
template
<
typename
T
>
template
<
typename
T
>
class
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
"This kernel only runs on CPU."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
int
batch_size
=
x
->
dims
()[
0
];
auto
*
y_data
=
y
->
data
<
T
>
();
if
(
ctx
.
Attr
<
bool
>
(
"softLabel"
))
{
auto
prob
=
EigenMatrix
<
T
>::
From
(
*
x
);
int
batch_size
=
x
->
dims
()[
0
];
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
labels
);
int
class_num
=
x
->
dims
()[
1
];
auto
loss
=
EigenMatrix
<
T
>::
From
(
*
y
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
loss
.
device
(
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
-
((
lbl_mat
*
prob
.
log
().
unaryExpr
(
TolerableValue
<
T
>
()))
int
index
=
0
;
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
))
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
)));
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
sum
+=
label_data
[
index
]
*
tolerable_value
(
std
::
log
(
x_data
[
index
]));
y_data
[
i
]
=
-
sum
;
index
++
;
}
}
}
else
{
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
const
int
class_num
=
x
->
dims
()[
1
];
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int
*
label_data
=
labels
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
int
index
=
i
*
class_num
+
label_data
[
i
];
y_data
[
i
]
=
-
tolerable_value
(
std
::
log
(
x_data
[
index
]));
y_data
[
i
]
=
-
TolerableValue
<
T
>
()
(
std
::
log
(
x_data
[
index
]));
}
}
}
}
}
}
...
@@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
...
@@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"
It must use CPUPlace
."
);
"
This kernel only runs on CPU
."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
)
);
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
)
);
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y
"
));
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X
"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()
);
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
int
class_num
=
x
->
dims
()[
1
];
if
(
ctx
.
Attr
<
bool
>
(
"softLabel"
))
{
// TODO(qingqing): make zero setting an common function.
auto
x_mat
=
EigenMatrix
<
T
>::
From
(
*
x
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
dy_mat
=
EigenMatrix
<
T
>::
From
(
*
dy
);
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
label
);
int
index
=
0
;
auto
dx_mat
=
EigenMatrix
<
T
>::
From
(
*
dx
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
dx_mat
.
device
(
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
dx_data
[
index
]
=
-
label_data
[
index
]
*
dy_data
[
i
]
/
x_data
[
index
];
-
(
lbl_mat
*
dy_mat
.
broadcast
(
Eigen
::
DSizes
<
int
,
2
>
(
1
,
class_num
))
/
index
++
;
x_mat
);
}
}
}
else
{
}
else
{
auto
*
label_data
=
label
->
data
<
int
>
();
int
batch_size
=
x
->
dims
()[
0
];
const
T
*
dy_data
=
dy
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int
*
label_data
=
label
->
data
<
int
>
();
// TODO(qingqing): make zero setting a common function.
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
int
index
=
i
*
class_num
+
label_data
[
i
];
int
index
=
i
*
class_num
+
label_data
[
i
];
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
62d597c1
...
@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
...
@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
.
device_context
())
.
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
}
}
};
};
...
@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
...
@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
K
,
D
);
grids
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
.
device_context
())
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
}
};
};
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
62d597c1
...
@@ -18,7 +18,6 @@ namespace paddle {
...
@@ -18,7 +18,6 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
class
MultiplexOp
:
public
framework
::
OperatorWithKernel
{
class
MultiplexOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
...
@@ -26,24 +25,31 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -26,24 +25,31 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
.
MultiInputVar
(
"X"
).
empty
(),
PADDLE_ENFORCE
(
!
ctx
.
MultiInputVar
(
"X"
).
empty
(),
"
Input(X) should not be null
"
);
"
MultiInput(X) shouldn't be empty.
"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Out"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Out"
),
"Output(Out) shouldn't be null."
);
"Output(Out) shouldn't be null."
);
auto
ids_dim
=
ctx
.
Input
<
Tensor
>
(
"Ids"
)
->
dims
();
PADDLE_ENFORCE
(
ids_dim
.
size
()
==
2
&&
ids_dim
[
1
]
==
1
,
"The index tensor must be a vector with size batchSize x 1."
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoD
Tensor
>
(
"Out"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
num_ins
=
ins
.
size
();
auto
num_ins
=
ins
.
size
();
PADDLE_ENFORCE
(
num_ins
>
2
,
PADDLE_ENFORCE
(
num_ins
>
1
,
"multiplex operator should have more than 2 inputs."
);
"multiplex operator should have more than "
PADDLE_ENFORCE_EQ
(
ins
[
0
]
->
dims
().
size
(),
1
,
"one candidate input tensors."
);
"The first input must be a index vector."
);
auto
in_dim
=
ins
[
1
]
->
dims
();
auto
in_dim
=
ins
[
0
]
->
dims
();
PADDLE_ENFORCE
(
in_dim
.
size
()
>=
2
,
for
(
size_t
i
=
2
;
i
<
num_ins
;
i
++
)
{
"The rank of candidate tensors must be not less than 2."
);
for
(
size_t
i
=
1
;
i
<
num_ins
;
i
++
)
{
auto
dim
=
ins
[
i
]
->
dims
();
auto
dim
=
ins
[
i
]
->
dims
();
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
in_dim
==
dim
,
in_dim
==
dim
,
"All the candidate tensors must have the same size."
);
"All the input tensors except the first one must have the same size"
);
}
}
out
->
Resize
(
in_dim
);
out
->
Resize
(
in_dim
);
}
}
...
@@ -54,25 +60,25 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -54,25 +60,25 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
MultiplexOpMaker
(
framework
::
OpProto
*
proto
,
MultiplexOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensors of multiplex operator."
).
AsDuplicable
();
AddInput
(
"Ids"
,
"The index tensor of multiplex operator."
);
AddInput
(
"X"
,
"The candidate tensors of multiplex operator."
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"The output tensor of multiplex operator."
);
AddOutput
(
"Out"
,
"The output tensor of multiplex operator."
);
AddComment
(
R"DOC(Multiplex operator
AddComment
(
R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
Multiplex multiple tensors according to the index provided by the index tensor.
input tensor.
ins[0]
: the index tensor.
Ids
: the index tensor.
ins[1:N]: the candidate output tensors
.
X[0 : N - 1]: the candidate tensors for output (N >= 2)
.
For each index i from 0 to batchSize - 1, the output is the i-th row of the
For each index i from 0 to batchSize - 1, the output is the i-th row of the
the (
index[i] + 1
)-th tensor.
the (
Ids[i]
)-th tensor.
For i-th row of the output tensor:
For i-th row of the output tensor:
y[i]
[j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
y[i]
= x_{k}[i]
where y is the output tensor. `x_{k}` is the k-th input tensor
where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = x{0}[i] + 1`.
and `k = Ids[i]`.
)DOC"
);
)DOC"
);
}
}
};
};
...
@@ -84,15 +90,15 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -84,15 +90,15 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
.
MultiInputVar
(
"X"
).
empty
(),
PADDLE_ENFORCE
(
!
ctx
.
MultiInputVar
(
"X"
).
empty
(),
"Input(X) should not be null"
);
"Input(X) should not be null
.
"
);
PADDLE_ENFORCE
(
!
ctx
.
MultiOutputVar
(
framework
::
GradVarName
(
"X"
)).
empty
(),
PADDLE_ENFORCE
(
!
ctx
.
MultiOutputVar
(
framework
::
GradVarName
(
"X"
)).
empty
(),
"Output(X@Grad) should not be null"
);
"Output(X@Grad) should not be null
.
"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should
n'
t be null."
);
"Input(Out@GRAD) should
no
t be null."
);
auto
d_ins
=
ctx
.
MultiOutput
<
LoD
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d_ins
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
//
don't compute gradient for index (ins[0]
)
//
No need to compute gradient for Input(Ids
)
for
(
size_t
i
=
1
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
d_ins
[
i
])
{
if
(
d_ins
[
i
])
{
d_ins
[
i
]
->
Resize
(
ins
[
i
]
->
dims
());
d_ins
[
i
]
->
Resize
(
ins
[
i
]
->
dims
());
}
}
...
...
paddle/operators/multiplex_op.cu
浏览文件 @
62d597c1
...
@@ -18,27 +18,30 @@
...
@@ -18,27 +18,30 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
MultiplexGPUKernel
:
public
framework
::
OpKernel
{
class
MultiplexGPUKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out
"
);
auto
*
ids
=
ctx
.
Input
<
Tensor
>
(
"Ids
"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
]
;
auto
cols
=
ins
[
0
]
->
numel
()
/
rows
;
// copy index to cpu
// copy index to cpu
framework
::
Tensor
index_t_cpu
;
Tensor
index_t_cpu
;
index_t_cpu
.
CopyFrom
<
T
>
(
*
(
ins
[
0
])
,
platform
::
CPUPlace
());
index_t_cpu
.
CopyFrom
<
int32_t
>
(
*
ids
,
platform
::
CPUPlace
());
auto
*
index
=
index_t_cpu
.
data
<
T
>
();
auto
*
index
=
index_t_cpu
.
data
<
int32_t
>
();
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
ctx
.
device_context
())
.
stream
();
.
stream
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
int
k
=
(
int
)
index
[
i
]
+
1
;
int32_t
k
=
index
[
i
];
PADDLE_ENFORCE_GE
(
k
,
0
,
"index must be nonnegative."
);
PADDLE_ENFORCE_LT
(
k
,
ins
.
size
(),
PADDLE_ENFORCE_LT
(
k
,
ins
.
size
(),
"index exceeds the number of candidate tensors."
);
"index exceeds the number of candidate tensors."
);
memory
::
Copy
(
place
,
out
->
data
<
T
>
()
+
i
*
cols
,
place
,
memory
::
Copy
(
place
,
out
->
data
<
T
>
()
+
i
*
cols
,
place
,
...
@@ -51,11 +54,11 @@ template <typename Place, typename T>
...
@@ -51,11 +54,11 @@ template <typename Place, typename T>
class
MultiplexGradGPUKernel
:
public
framework
::
OpKernel
{
class
MultiplexGradGPUKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_out
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
d_ins
=
auto
*
ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d_ins
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
for
(
size_t
i
=
1
;
i
<
d_ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
d_ins
.
size
();
i
++
)
{
if
(
d_ins
[
i
])
{
if
(
d_ins
[
i
])
{
d_ins
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_ins
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_ins
[
i
]);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_ins
[
i
]);
...
@@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
...
@@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
}
}
}
}
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
]
;
auto
cols
=
ins
[
0
]
->
numel
()
/
rows
;
// copy index to cpu
// copy index to cpu
framework
::
Tensor
index_t_cpu
;
Tensor
index_t_cpu
;
index_t_cpu
.
CopyFrom
<
T
>
(
*
(
ins
[
0
])
,
platform
::
CPUPlace
());
index_t_cpu
.
CopyFrom
<
int32_t
>
(
*
ids
,
platform
::
CPUPlace
());
auto
*
index
=
index_t_cpu
.
data
<
T
>
();
auto
*
index
=
index_t_cpu
.
data
<
int32_t
>
();
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
ctx
.
device_context
())
.
stream
();
.
stream
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
int
k
=
(
int
)
index
[
i
]
+
1
;
size_t
k
=
static_cast
<
size_t
>
(
index
[
i
])
;
if
(
d_ins
[
k
])
{
if
(
d_ins
[
k
])
{
memory
::
Copy
(
place
,
d_ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
place
,
memory
::
Copy
(
place
,
d_ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
place
,
d_out
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
),
stream
);
d_out
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
),
stream
);
...
...
paddle/operators/multiplex_op.h
浏览文件 @
62d597c1
...
@@ -27,16 +27,18 @@ class MultiplexCPUKernel : public framework::OpKernel {
...
@@ -27,16 +27,18 @@ class MultiplexCPUKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
ids
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Ids"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
]
;
auto
cols
=
ins
[
0
]
->
numel
()
/
rows
;
auto
*
index
=
ins
[
0
]
->
data
<
T
>
();
auto
index
=
ids
->
data
<
int32_t
>
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
int
k
=
(
int
)
index
[
i
]
+
1
;
int32_t
k
=
index
[
i
];
PADDLE_ENFORCE_GE
(
k
,
0
,
"index must be nonnegative."
);
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
k
),
ins
.
size
(),
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
k
),
ins
.
size
(),
"index exceeds the number of candidate tensors."
);
"index exceeds the number of candidate tensors."
);
memory
::
Copy
(
place
,
out
->
data
<
T
>
()
+
i
*
cols
,
place
,
memory
::
Copy
(
place
,
out
->
data
<
T
>
()
+
i
*
cols
,
place
,
...
@@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
...
@@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
ids
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Ids"
);
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
d_ins
=
auto
d_ins
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
for
(
size_t
i
=
1
;
i
<
d_ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
d_ins
.
size
();
i
++
)
{
if
(
d_ins
[
i
])
{
if
(
d_ins
[
i
])
{
d_ins
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_ins
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_ins
[
i
]);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_ins
[
i
]);
...
@@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
...
@@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
}
}
}
}
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
]
;
auto
cols
=
ins
[
0
]
->
numel
()
/
rows
;
auto
*
index
=
i
ns
[
0
]
->
data
<
T
>
();
auto
*
index
=
i
ds
->
data
<
int32_t
>
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
int
k
=
(
int
)
index
[
i
]
+
1
;
size_t
k
=
static_cast
<
size_t
>
(
index
[
i
])
;
if
(
d_ins
[
k
])
{
if
(
d_ins
[
k
])
{
memory
::
Copy
(
place
,
d_ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
place
,
memory
::
Copy
(
place
,
d_ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
place
,
d_out
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
));
d_out
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
));
...
...
paddle/operators/top_k_op.cu
浏览文件 @
62d597c1
...
@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
...
@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
// NOTE: pass lds and dim same to input width.
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): launch kernel on specified stream.
// TODO(typhoonzero): refine this kernel.
// TODO(typhoonzero): refine this kernel.
dim3
threads
(
256
,
1
);
dim3
threads
(
256
,
1
);
dim3
grid
(
input_height
,
1
);
dim3
grid
(
input_height
,
1
);
KeMatrixTopK
<
T
,
5
,
256
><<<
grid
,
threads
>>>
(
KeMatrixTopK
<
T
,
5
,
256
><<<
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
input_width
,
int
(
k
));
ctx
.
device_context
())
.
stream
()
>>>
(
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
int
(
k
));
}
}
};
};
...
...
paddle/parameter/FirstOrderOptimizer.h
浏览文件 @
62d597c1
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "ParameterOptimizer.h"
#include "ParameterOptimizer.h"
#include "ParameterUpdateFunctions.h"
#include "Regularizer.h"
#include "Regularizer.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -37,6 +38,15 @@ public:
...
@@ -37,6 +38,15 @@ public:
real
torch_learningRate
=
optConfig_
.
learning_method
()
==
"torch_momentum"
real
torch_learningRate
=
optConfig_
.
learning_method
()
==
"torch_momentum"
?
1.0
-
paraConfig
.
momentum
()
?
1.0
-
paraConfig
.
momentum
()
:
1.0
;
:
1.0
;
#ifdef PADDLE_USE_MKLDNN
sgdUpdate
(
learningRate_
*
paraConfig
.
learning_rate
()
*
(
firstTime_
?
1.0
:
torch_learningRate
),
paraConfig
.
momentum
(),
applyDecay_
?
paraConfig
.
decay_rate
()
:
0
,
vecs
[
PARAMETER_VALUE
].
get
(),
vecs
[
PARAMETER_GRADIENT
].
get
(),
vecs
[
PARAMETER_MOMENTUM
].
get
());
#else
vecs
[
PARAMETER_VALUE
]
->
sgdUpdate
(
vecs
[
PARAMETER_VALUE
]
->
sgdUpdate
(
*
vecs
[
PARAMETER_GRADIENT
],
*
vecs
[
PARAMETER_GRADIENT
],
*
vecs
[
PARAMETER_MOMENTUM
],
*
vecs
[
PARAMETER_MOMENTUM
],
...
@@ -44,6 +54,7 @@ public:
...
@@ -44,6 +54,7 @@ public:
(
firstTime_
?
1.0
:
torch_learningRate
),
(
firstTime_
?
1.0
:
torch_learningRate
),
paraConfig
.
momentum
(),
paraConfig
.
momentum
(),
applyDecay_
?
paraConfig
.
decay_rate
()
:
0
);
applyDecay_
?
paraConfig
.
decay_rate
()
:
0
);
#endif
}
}
virtual
void
finishBatch
()
{
firstTime_
=
false
;
}
virtual
void
finishBatch
()
{
firstTime_
=
false
;
}
};
};
...
...
paddle/parameter/ParameterUpdateFunctions.cpp
浏览文件 @
62d597c1
...
@@ -30,6 +30,9 @@ void sgdUpdateCpu(real learningRate,
...
@@ -30,6 +30,9 @@ void sgdUpdateCpu(real learningRate,
const
real
*
grad
,
const
real
*
grad
,
real
*
momentumVec
)
{
real
*
momentumVec
)
{
decayRate
*=
learningRate
;
decayRate
*=
learningRate
;
#ifdef PADDLE_USE_MKLDNN
#pragma omp parallel for
#endif
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
momentumVec
[
i
]
=
momentum
*
momentumVec
[
i
]
-
learningRate
*
grad
[
i
]
-
momentumVec
[
i
]
=
momentum
*
momentumVec
[
i
]
-
learningRate
*
grad
[
i
]
-
decayRate
*
value
[
i
];
decayRate
*
value
[
i
];
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
62d597c1
...
@@ -1566,7 +1566,7 @@ class LayerBase(object):
...
@@ -1566,7 +1566,7 @@ class LayerBase(object):
self
.
config
=
g_config
.
model_config
.
layers
.
add
()
self
.
config
=
g_config
.
model_config
.
layers
.
add
()
assert
isinstance
(
self
.
config
,
LayerConfig
)
assert
isinstance
(
self
.
config
,
LayerConfig
)
use_mkldnn
=
bool
(
int
(
g_command_config_args
.
get
(
"use_mkldnn"
,
0
)))
use_mkldnn
=
bool
(
int
(
g_command_config_args
.
get
(
"use_mkldnn"
,
0
)))
mkldnn_acts
=
[
'relu'
,
'tanh'
]
mkldnn_acts
=
[
'relu'
,
'tanh'
,
'softmax'
]
if
use_mkldnn
and
active_type
in
mkldnn_acts
:
if
use_mkldnn
and
active_type
in
mkldnn_acts
:
active_type
=
"mkldnn_"
+
active_type
active_type
=
"mkldnn_"
+
active_type
self
.
config
.
name
=
name
self
.
config
.
name
=
name
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
62d597c1
...
@@ -4,22 +4,24 @@ from op_test import OpTest
...
@@ -4,22 +4,24 @@ from op_test import OpTest
class
TestCrossEntropyOp1
(
OpTest
):
class
TestCrossEntropyOp1
(
OpTest
):
"""Test
standard cross-entropy, with index representation of
labels.
"""Test
cross-entropy with discrete one-hot
labels.
"""
"""
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
self
.
op_type
=
"cross_entropy"
batch_size
=
30
batch_size
=
30
class_num
=
10
class_num
=
10
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
,
1
),
dtype
=
"int32"
)
label
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
,
1
),
dtype
=
"int32"
)
cross_entropy
=
np
.
asmatrix
(
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
X
[
i
][
label
[
i
][
0
]])]
for
i
in
range
(
X
.
shape
[
0
])],
[[
-
np
.
log
(
X
[
i
][
label
[
i
][
0
]])]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float32"
)
dtype
=
"float32"
)
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
False
}
self
.
attrs
=
{
"softLabel"
:
False
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -29,13 +31,14 @@ class TestCrossEntropyOp1(OpTest):
...
@@ -29,13 +31,14 @@ class TestCrossEntropyOp1(OpTest):
class
TestCrossEntropyOp2
(
OpTest
):
class
TestCrossEntropyOp2
(
OpTest
):
"""Test
soft-label cross-entropy, with vecte
rized soft labels.
"""Test
cross-entropy with vecto
rized soft labels.
"""
"""
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
self
.
op_type
=
"cross_entropy"
batch_size
=
10
batch_size
=
5
class_num
=
5
class_num
=
37
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
np
.
random
.
uniform
(
0.1
,
1.0
,
label
=
np
.
random
.
uniform
(
0.1
,
1.0
,
...
@@ -43,46 +46,49 @@ class TestCrossEntropyOp2(OpTest):
...
@@ -43,46 +46,49 @@ class TestCrossEntropyOp2(OpTest):
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
cross_entropy
=
(
-
label
*
np
.
log
(
X
)).
sum
(
cross_entropy
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
attrs
=
{
'soft_label'
:
True
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"softLabel"
:
True
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
self
.
check_grad
([
"X"
],
"Y"
,
max_relative_error
=
0.05
)
class
TestCrossEntropyOp3
(
OpTest
):
class
TestCrossEntropyOp3
(
OpTest
):
"""Test one-hot cross-entropy, with vecterized one-hot representation of
"""Test cross-entropy with vectorized one-hot representation of labels.
labels.
"""
"""
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
self
.
op_type
=
"cross_entropy"
batch_size
=
30
batch_size
=
5
class_num
=
10
class_num
=
17
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label_index
=
np
.
random
.
randint
(
label_index
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
),
dtype
=
"int32"
)
0
,
class_num
,
(
batch_size
),
dtype
=
"int32"
)
label
=
np
.
zeros
(
X
.
shape
)
label
=
np
.
zeros
(
X
.
shape
)
label
[
np
.
arange
(
batch_size
),
label_index
]
=
1
label
[
np
.
arange
(
batch_size
),
label_index
]
=
1
cross_entropy
=
np
.
asmatrix
(
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
X
[
i
][
label_index
[
i
]])]
for
i
in
range
(
X
.
shape
[
0
])],
[[
-
np
.
log
(
X
[
i
][
label_index
[
i
]])]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float32"
)
dtype
=
"float32"
)
cross_entropy2
=
(
-
label
*
np
.
log
(
X
)).
sum
(
cross_entropy2
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
attrs
=
{
'soft_label'
:
True
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"softLabel"
:
True
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
self
.
check_grad
([
"X"
],
"Y"
,
max_relative_error
=
0.05
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
python/paddle/v2/framework/tests/test_multiplex_op.py
浏览文件 @
62d597c1
...
@@ -6,20 +6,22 @@ from op_test import OpTest
...
@@ -6,20 +6,22 @@ from op_test import OpTest
class
TestMultiplexOp
(
OpTest
):
class
TestMultiplexOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"multiplex"
self
.
op_type
=
"multiplex"
rows
=
3
rows
=
4
index
=
np
.
array
([
3
,
1
,
0
])
index
=
np
.
arange
(
0
,
rows
).
astype
(
'int32'
)
np
.
random
.
shuffle
(
index
)
index
=
np
.
reshape
(
index
,
(
rows
,
1
))
ins1
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins1
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins2
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins2
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins3
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins3
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins4
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins4
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
self
.
inputs
=
{
self
.
inputs
=
{
'
X'
:
[(
'index'
,
index
),
(
'x1'
,
ins1
),
(
'x2'
,
ins2
),
(
'x3'
,
ins3
)
,
'
Ids'
:
index
,
(
'x4'
,
ins4
)]
'X'
:
[(
'x1'
,
ins1
),
(
'x2'
,
ins2
),
(
'x3'
,
ins3
),
(
'x4'
,
ins4
)]
}
}
# multiplex output
# multiplex output
output
=
np
.
zeros_like
(
ins1
)
output
=
np
.
zeros_like
(
ins1
)
for
i
in
range
(
0
,
rows
):
for
i
in
range
(
0
,
rows
):
k
=
index
[
i
]
+
1
k
=
index
[
i
]
[
0
]
output
[
i
]
=
self
.
inputs
[
'X'
][
k
][
1
][
i
]
output
[
i
]
=
self
.
inputs
[
'X'
][
k
][
1
][
i
]
self
.
outputs
=
{
'Out'
:
output
}
self
.
outputs
=
{
'Out'
:
output
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录