diff --git a/benchmark/paddle/image/run_mkldnn.sh b/benchmark/paddle/image/run_mkldnn.sh index 5b0a0373448e5b81ff0718db3465a4694690ec37..b6cd6fe03b381d2b6529116f934ce7ce03d63546 100755 --- a/benchmark/paddle/image/run_mkldnn.sh +++ b/benchmark/paddle/image/run_mkldnn.sh @@ -9,11 +9,9 @@ function train() { bs=$2 use_mkldnn=$3 if [ $3 == "True" ]; then - use_mkldnn=$3 thread=1 log="logs/${topology}-mkldnn-${bs}.log" elif [ $3 == "False" ]; then - use_mkldnn=$3 thread=`nproc` log="logs/${topology}-${thread}mklml-${bs}.log" else @@ -39,8 +37,7 @@ if [ ! -d "logs" ]; then mkdir logs fi -#========= mkldnn =========# -# vgg +#========== mkldnn ==========# train vgg 64 True train vgg 128 True train vgg 256 True diff --git a/doc/design/refactor/distributed_architecture.md b/doc/design/refactor/distributed_architecture.md new file mode 100644 index 0000000000000000000000000000000000000000..ac7e98ccf1aadbb973a4801fde842375cf63448c --- /dev/null +++ b/doc/design/refactor/distributed_architecture.md @@ -0,0 +1,222 @@ +# Design Doc: Distributed Training Architecture + +## Abstract + +PaddlePaddle v0.10.0 uses the "trainer-parameter server" +architecture. We run multiple replicated instances of trainers (runs +the same code written by the user) and parameter servers for +distributed training. This architecture served us well, but has some +limitations: + +1. Need to write special code to handle tasks which should only be run + by a single trainer. E.g., initializing model and saving model. + +2. Model parallelism is hard: need to write if-else branches conditioned + on the trainer ID to partition model onto each trainer, and manually + write the inter-model-shard communication code. + +3. The user can not directly specify the parameter update rule: need + to modify the parameter server C++ code and compile a new + binary. This adds complication for researchers: A lot of extra + effort is required. Besides, the training job submission program + may not allow running arbitrary binaries. + +This design doc discusses PaddlePaddle's new distributed training +architecture that addresses the above limitations. + +## Analysis + +We will assume the user writes the trainer program by Python, the same +analysis holds if the trainer program is written in C++. + +### Limitation 1 + +If we look at the Python code that the user writes, there are two +kinds of functionalities: + +- The training logic such as load / save model and print log. +- The neural network definition such as the definition of the data + layer, the fully connected layer, the cost function and the + optimizer. + +When we training with PaddlePaddle v0.10.0 distributedly, multiple +replicated Python instances are running on different nodes: both the +training logic and the neural network computation is replicated. + +The tasks that should only run once all belong to the training logic, +if we only replicate the neural network computation, but do **not** +replicate the training logic, the limitation could be solved. + +### Limitation 2 + +Model parallelism means running a single model on multiple nodes by +partitioning the model onto different nodes and managing the +inter-model-shard communications. + +PaddlePaddle should be able to modify the nerual network computation +definition to support model parallelism automatically. However, the +computation is only specified in Python code, and PaddlePaddle can not +modify Python code. + +Just like compiler uses a intermediate representation (IR) so that +programmer does not need to manually optimize their code in most of +the cases - the compiler will optimize the IR: + + + +We can have our own IR too: PaddlePaddle can support model parallel by +converting the IR so the user no longer need to manually do it in +Python: + + + +The IR for PaddlePaddle after refactor is called `Block`, it specifies +the computation dependency graph and the variables used in the +computation. + +### Limitation 3 + +The user can not directly specify the parameter update rule for the +parameter server because the parameter server does not use the same +computation definition as the trainer. Instead, the update rule is +baked in the parameter server. The user can not specify the update +rule in the same way of specifying the trainer computation. + +This could be fixed by making the parameter server run the same +computation definition as the trainer. For a detailed explanation, +please +see +[Design Doc: Operation Graph Based Parameter Server](./dist_train.md) + +## Distributed Training Architecture + +The new distributed training architecture can address the above +limitations. Below is the illustration: + + + +The architecture includes major components: *PaddlePaddle Python*, +*PaddlePaddle converter* and *PaddlePaddle runtime*: + +### PaddlePaddle Python + +PaddlePaddle Python is the Python library that user's Python trainer +invoke to build the neural network topology, start training, etc. + +```Python +paddle.init() +input = paddle.op.recordIO("/home/data/mnist.recordio") # file stored on the cluster +img, label = input[0], input[1] +hidden = paddle.layer.fc(input=img, size=200, act=paddle.activation.Tanh()) +prediction = paddle.layer.fc(input=img, size=10, act=paddle.activation.Softmax()) +cost = paddle.layer.classification_cost(input=prediction, label=label) +optimizer = paddle.optimizer.SGD(cost, learning_rate=0.01) +session = paddle.session.NewRemote(num_trainer=3, num_ps=2, GPU_per_trainer=1) +for i in range(1000): + _, cost_val = session.eval(targets=[cost, optimizer]) + print cost_val +``` + +The code above is a typical Python trainer code, the neural network +topology is built using helper functions such as +`paddle.layer.fc`. The training is done by calling `session.eval` +iteratively. + +#### session.eval + +As shown in the graph, `session.eval` sends the IR and the evaluation +inputs/targets to the PaddlePaddle cluster for evaluation. The +targets can be any variable in the computation graph. When the target +is the `optimizer` variable, the neural network will be optimized +once. When the target is the `cost` variable, `session.eval` returns +the cost value. + +The Python `session` is a wrapper of the C++ `Session` class. For more +information about `Session`, please +see [Design Doc: Session](./session.md). + +### PaddlePaddle Converter + +PaddlePaddle converter automatically converts the IR in the request +(IR and evaluation inputs/targets) from PaddlePaddle Python to new +partitioned IRs and dispatch the new IRs and evaluation inputs/targets +to different PaddlePaddle runtimes. Below are the steps: + +1. Add `feed` OP that feeds the eval inputs, and `fetch` OP that + fetches the eval targets to the IR. + +1. Extract a new computation (sub)graph with `feed` and `fetch` OP as + the boundary. The runtime does not need to run the OP that is not + dependent by the `fetch` OP. + +1. Optimizes the computation graph. + +1. Place the OPs in the graph onto different devices on different + PaddlePaddle runtime according to a placement algorithm and device + constraint specified by the user. + +1. Partition the graph according to runtime boundaries and add `send` / + `recv` OP pair on the runtime boundaries. + +1. Dispatch the partitioned graph to different PaddlePaddle runtimes. + +1. PaddlePaddle runtimes with the `fetch` OP reports evaluation + results back to the converter, the convert reports the evaluation + results back to the PaddlePaddle Python. + +The output IRs will be cached to optimize the conversion latency. + + +#### Placement Algorithm + +Our first implementation will only support "trainer-parameter server" +placement: the parameters, initializers, and optimizers are placed on +the PaddlePaddle runtimes with the parameter server role. And +everything else will be placed on the PaddlePaddle runtimes with the +trainer role. This has the same functionality of our +"trainer-parameter server" architecture of PaddlePaddle v0.10.0, but +is more general and flexible. + +In the future, we will implement the general placement algorithm, +which makes placements according to the input IR, and a model of +device computation time and device communication time. Model +parallelism requires the general placement algorithm. + + +### PaddlePaddle Runtime + +The PaddlePaddle runtime owns multiple devices (e.g., CPUs, GPUs) and +runs the IR. The runtime does not need to do OP placement since it's +already done by the converter. + + +### Local Training Architecture + +The local training architecture will be the same as the distributed +training architecture, the differences are everything runs locally, +and there is just one PaddlePaddle runtime: + + + + +### Training Data + +In PaddlePaddle v0.10.0, training data is typically read +with [data reader](../reader/README.md) from Python. This approach is +no longer efficient when training distributedly since the Python +process no longer runs on the same node with the trainer processes, +the Python reader will need to read from the distributed filesystem +(assuming it has the access) and send to the trainers, doubling the +network traffic. + +When doing distributed training, the user can still use Python data +reader: the training data are sent with `session.eval`. However should +be used for debugging purpose only. The users are encouraged to use +the read data OPs. + + +## References: + +[1] [TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf) + +[2] [TensorFlow: A System for Large-Scale Machine Learning](https://www.usenix.org/system/files/conference/osdi16/osdi16-abadi.pdf) diff --git a/doc/design/ops/dist_train.md b/doc/design/refactor/parameter_server.md similarity index 100% rename from doc/design/ops/dist_train.md rename to doc/design/refactor/parameter_server.md diff --git a/doc/design/refactor/src/compiler.graffle b/doc/design/refactor/src/compiler.graffle new file mode 100644 index 0000000000000000000000000000000000000000..8cc678fea3c820103e7ce81f7a5d625d6c1d92de Binary files /dev/null and b/doc/design/refactor/src/compiler.graffle differ diff --git a/doc/design/refactor/src/compiler.png b/doc/design/refactor/src/compiler.png new file mode 100644 index 0000000000000000000000000000000000000000..65d34f841afce9756def07dd8ecb9ca44e658bfe Binary files /dev/null and b/doc/design/refactor/src/compiler.png differ diff --git a/doc/design/ops/src/dist-graph.graffle b/doc/design/refactor/src/dist-graph.graffle similarity index 100% rename from doc/design/ops/src/dist-graph.graffle rename to doc/design/refactor/src/dist-graph.graffle diff --git a/doc/design/ops/src/dist-graph.png b/doc/design/refactor/src/dist-graph.png similarity index 100% rename from doc/design/ops/src/dist-graph.png rename to doc/design/refactor/src/dist-graph.png diff --git a/doc/design/refactor/src/distributed_architecture.graffle b/doc/design/refactor/src/distributed_architecture.graffle new file mode 100644 index 0000000000000000000000000000000000000000..f8496e57326c38de7468eb452a7713291d57653c Binary files /dev/null and b/doc/design/refactor/src/distributed_architecture.graffle differ diff --git a/doc/design/refactor/src/distributed_architecture.png b/doc/design/refactor/src/distributed_architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..410c4510c6aab301dec95e6427fe80ac24e105fe Binary files /dev/null and b/doc/design/refactor/src/distributed_architecture.png differ diff --git a/doc/design/ops/src/local-graph.graffle b/doc/design/refactor/src/local-graph.graffle similarity index 100% rename from doc/design/ops/src/local-graph.graffle rename to doc/design/refactor/src/local-graph.graffle diff --git a/doc/design/ops/src/local-graph.png b/doc/design/refactor/src/local-graph.png similarity index 100% rename from doc/design/ops/src/local-graph.png rename to doc/design/refactor/src/local-graph.png diff --git a/doc/design/refactor/src/local_architecture.graffle b/doc/design/refactor/src/local_architecture.graffle new file mode 100644 index 0000000000000000000000000000000000000000..cc7783c45381f25ded0b898649322c81418ad317 Binary files /dev/null and b/doc/design/refactor/src/local_architecture.graffle differ diff --git a/doc/design/refactor/src/local_architecture.png b/doc/design/refactor/src/local_architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..4b999538b7825c805292ee28b5e3256d5543bd09 Binary files /dev/null and b/doc/design/refactor/src/local_architecture.png differ diff --git a/doc/design/refactor/src/paddle-compile.graffle b/doc/design/refactor/src/paddle-compile.graffle new file mode 100644 index 0000000000000000000000000000000000000000..a6348cc3dbcaca923c6e794681b2edb85cb9f8f6 Binary files /dev/null and b/doc/design/refactor/src/paddle-compile.graffle differ diff --git a/doc/design/refactor/src/paddle-compile.png b/doc/design/refactor/src/paddle-compile.png new file mode 100644 index 0000000000000000000000000000000000000000..e0f13d551ac41afaec627a57dea79356464bf0bf Binary files /dev/null and b/doc/design/refactor/src/paddle-compile.png differ diff --git a/doc/faq/index_cn.rst b/doc/faq/index_cn.rst index b3ecfba791ead8349ded018a30059b03eacbdacd..d69d111917ca7a79bc65b051c8eefaba165d77bd 100644 --- a/doc/faq/index_cn.rst +++ b/doc/faq/index_cn.rst @@ -247,11 +247,11 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字 CMake Warning at cmake/version.cmake:20 (message): Cannot add paddle version from git tag - + 那么用户需要拉取所有的远程分支到本机,命令为 :code:`git fetch upstream`,然后重新cmake即可。 12. A protocol message was rejected because it was too big ----------------------------------------------------------- +------------------------------------------------------------ 如果在训练NLP相关模型时,出现以下错误: @@ -316,10 +316,42 @@ Paddle二进制在运行时捕获了浮点数异常,只要出现浮点数异 * 模型一直不收敛,发散到了一个数值特别大的地方。 * 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。 -主要的解决办法是减小学习率或者对数据进行归一化处理。 +这里有两种有效的解决方法: + +1. 设置 :code:`gradient_clipping_threshold` 参数,示例代码如下: + +.. code-block:: python + +optimizer = paddle.optimizer.RMSProp( + learning_rate=1e-3, + gradient_clipping_threshold=10.0, + regularization=paddle.optimizer.L2Regularization(rate=8e-4)) + +具体可以参考 `nmt_without_attention `_ 示例。 + +2. 设置 :code:`error_clipping_threshold` 参数,示例代码如下: + +.. code-block:: python + +decoder_inputs = paddle.layer.fc( + act=paddle.activation.Linear(), + size=decoder_size * 3, + bias_attr=False, + input=[context, current_word], + layer_attr=paddle.attr.ExtraLayerAttribute( + error_clipping_threshold=100.0)) + +完整代码可以参考示例 `machine translation `_ 。 + +两种方法的区别: + +1. 两者都是对梯度的截断,但截断时机不同,前者在 :code:`optimzier` 更新网络参数时应用;后者在激活函数反向计算时被调用; +2. 截断对象不同:前者截断可学习参数的梯度,后者截断回传给前层的梯度; + +除此之外,还可以通过减小学习律或者对数据进行归一化处理来解决这类问题。 15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2 ------------------------------------------------------------------------- +------------------------------------------------------------------------------------------ 先查看一下是否曾经安装过paddle v1版本,有的话需要先卸载: pip uninstall py_paddle paddle @@ -329,7 +361,7 @@ pip uninstall py_paddle paddle pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl 16. PaddlePaddle存储的参数格式是什么,如何和明文进行相互转化 ---------------------------------------------------------- +--------------------------------------------------------------------- PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数两部分组成。头信息中,1~4字节表示PaddlePaddle版本信息,请直接填充0;5~8字节表示每个参数占用的字节数,当保存的网络参数为float类型时为4,double类型时为8;9~16字节表示保存的参数总个数。 @@ -381,7 +413,7 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数 parameters.set('emb', load_parameter(emb_param_file, 30000, 256)) 18. 集群多节点训练,日志中保存均为网络通信类错误 ------------------------------- +----------------------------------------------------------- 集群多节点训练,日志报错为网络通信类错误,比如 :code:`Connection reset by peer` 等。 此类报错通常是由于某一个节点的错误导致这个节点的训练进程退出,从而引发其他节点无法连接导致,可以参考下面的步骤排查: @@ -392,8 +424,8 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数 * 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。 -19. PaddlePaddle如何输出多个层 ------------------------------- +19. 如何调用 infer 接口输出多个layer的预测结果 +----------------------------------------------------------- * 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下: @@ -405,9 +437,28 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数 .. code-block:: python - out = inferer.infer(input=data_batch, flatten_result=False, field=["value"]) + out = inferer.infer(input=data_batch, field=["value"]) + +需要注意的是: -这里设置 :code:`flatten_result=False`,得到的输出结果是元素个数等于输出字段数的 :code:`list`,该 :code:`list` 的每个元素是由所有输出层相应字段结果组成的 :code:`list`,每个字段结果的类型是 :code:`numpy.array`。:code:`flatten_result` 的默认值为 :code:`True`,该情况下,PaddlePaddle会分别对每个字段将所有输出层的结果按行进行拼接,如果各输出层该字段 :code:`numpy.array` 结果的相应维数不匹配,程序将不能正常运行。 +* 如果指定了2个layer作为输出层,实际上需要的输出结果是两个矩阵; +* 假设第一个layer的输出A是一个 N1 * M1 的矩阵,第二个 Layer 的输出B是一个 N2 * M2 的矩阵; +* paddle.v2 默认会将A和B 横向拼接,当N1 和 N2 大小不一样时,会报如下的错误: + +.. code-block:: python + + ValueError: all the input array dimensions except for the concatenation axis must match exactly + +多个层的输出矩阵的高度不一致导致拼接失败,这种情况常常发生在: + +* 同时输出序列层和非序列层; +* 多个输出层处理多个不同长度的序列; + +此时可以在调用infer接口时通过设置 :code:`flatten_result=False` , 跳过“拼接”步骤,来解决上面的问题。这时,infer接口的返回值是一个python list: + +* list 中元素的个数等于网络中输出层的个数; +* list 中每个元素是一个layer的输出结果矩阵,类型是numpy的ndarray; +* 每一个layer输出矩阵的高度,在非序列输入时:等于样本数;序列输入时等于:输入序列中元素的总数;宽度等于配置中layer的size; 20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用 ------------------------------------------------------------- @@ -416,8 +467,8 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数 * PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。 -21. dropout 使用 ------------------ +21. 两种使用 drop_out 的方法有何区别? +----------------------------------------------------- * 在PaddlePaddle中使用dropout有两种方式 @@ -503,7 +554,7 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu optimizer = paddle.optimizer.Adam( learning_rate=1e-3, learning_rate_schedule="manual", - learning_rate_args="1:1.0,2:0.9,3:0.8",) + learning_rate_args="1:1.0,2:0.9,3:0.8",) 在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。 @@ -512,3 +563,30 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu 出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。 +24. PaddlePaddle 中不同的 recurrent layer 的区别 +-------------------------------------------------- +以LSTM为例,在PaddlePaddle中包含以下 recurrent layer: + +* :code:`paddle.layer.lstmemory` +* :code:`paddle.networks.simple_lstm` +* :code:`paddle.networks.lstmemory_group` +* :code:`paddle.networks.bidirectional_lstm` + +按照具体实现方式可以归纳为2类: + +1. 由 recurrent_group 实现的 recurrent layer: + + * 用户在使用这一类recurrent layer时,可以访问由recurrent unit在一个时间步内计算得到的中间值(例如:hidden states, memory cells等); + * 上述的 :code:`paddle.networks.lstmemory_group` 是这一类的 recurrent layer ; + +2. 将recurrent layer作为一个整体来实现: + + * 用户在使用这一类recurrent layer,只能访问它们的输出值; + * 上述的 :code:`paddle.networks.lstmemory_group` 、 :code:`paddle.networks.simple_lstm` 和 :code:`paddle.networks.bidirectional_lstm` 属于这一类的实现; + +将recurrent layer作为一个整体来实现, 能够针对CPU和GPU的计算做更多优化, 所以相比于recurrent group的实现方式, 第二类 recurrent layer 计算效率更高。 在实际应用中,如果用户不需要访问LSTM的中间变量,而只需要获得recurrent layer计算的输出,我们建议使用第二类实现。 + +此外,关于LSTM, PaddlePaddle中还包含 :code:`paddle.networks.lstmemory_unit` 这一计算单元: + + * 不同于上述介绍的recurrent layer , :code:`paddle.networks.lstmemory_unit` 定义了LSTM单元在一个时间步内的计算过程,它并不是一个完整的recurrent layer,也不能接收序列数据作为输入; + * :code:`paddle.networks.lstmemory_unit` 只能在recurrent_group中作为step function使用; diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index 84e33177740ca1652efc09c8081c2519b4366906..30b144d849bec367cd0197b6082889e011193a9a 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -20,7 +20,7 @@ Docker使用入门 docker pull paddlepaddle/paddle:0.10.0 - 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。 + 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用docker.paddlepaddle.org/paddle下载。 - *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。 实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。 diff --git a/paddle/gserver/activations/MKLDNNActivation.cpp b/paddle/gserver/activations/MKLDNNActivation.cpp index ac50937ef3e28c1ac5aae651f9cf266ad07abcc4..18c5638100065109fba1f0647a1c5f91256f7b9d 100644 --- a/paddle/gserver/activations/MKLDNNActivation.cpp +++ b/paddle/gserver/activations/MKLDNNActivation.cpp @@ -27,31 +27,53 @@ static ClassRegistrar gMKLDNNActivationRegistrar; #define MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) mkldnn_##ACT_TYPE##Activation /** - * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION + * @def BEGIN_MKLDNN_ACTIVATION + */ +#define BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \ + class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) : public BASE_CLASS { +/** + * @def END_MKLDNN_ACTIVATION */ -#define DEFINE_MKLDNN_ELTWISE_ACTIVATION(ACT_TYPE, ALPHA, BWD_ALPHA) \ - class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) \ - : public MKLDNNEltwiseActivation { \ - private: \ - static const std::string name; \ - static const float alpha; \ - static const float bwdAlpha; \ - \ - public: \ - const std::string& getName() const { return name; } \ - float getAlpha() const { return alpha; } \ - float getBwdAlpha() const { return bwdAlpha; } \ - }; \ - const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \ - "mkldnn_" #ACT_TYPE; \ - const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \ - const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA; \ - static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \ - gMKLDNNActivationRegistrar \ - .registerClass( \ - "mkldnn_" #ACT_TYPE); \ +#define END_MKLDNN_ACTIVATION(ACT_TYPE) \ +private: \ + static const std::string name; \ + \ +public: \ + const std::string& getName() const { return name; } \ + } \ + ; \ + const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \ + "mkldnn_" #ACT_TYPE; \ + static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \ + gMKLDNNActivationRegistrar \ + .registerClass( \ + "mkldnn_" #ACT_TYPE); \ }); +/** + * @def DEFINE_MKLDNN_ACTIVATION + */ +#define DEFINE_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \ + BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \ + END_MKLDNN_ACTIVATION(ACT_TYPE) + +/** + * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION + */ +#define DEFINE_MKLDNN_ELTWISE_ACTIVATION( \ + ACT_TYPE, BASE_CLASS, ALPHA, BWD_ALPHA) \ + BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \ +private: \ + static const float alpha; \ + static const float bwdAlpha; \ + \ +public: \ + float getAlpha() const { return alpha; } \ + float getBwdAlpha() const { return bwdAlpha; } \ + END_MKLDNN_ACTIVATION(ACT_TYPE) \ + const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \ + const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA; + /** * @brief MKLDNN Relu Activation. * Actually mkldnn_relu is Leaky Relu. @@ -59,19 +81,129 @@ static ClassRegistrar gMKLDNNActivationRegistrar; * f(x) = negative_slope * x (x < 0) * @note the negative_slope should be -0.f in forward */ -DEFINE_MKLDNN_ELTWISE_ACTIVATION(relu, -0.f, 0.f) +DEFINE_MKLDNN_ELTWISE_ACTIVATION(relu, MKLDNNEltwiseActivation, -0.f, 0.f) /** * @brief MKLDNN Tanh Activation. */ -DEFINE_MKLDNN_ELTWISE_ACTIVATION(tanh, 0.f, 0.f) +DEFINE_MKLDNN_ELTWISE_ACTIVATION(tanh, MKLDNNEltwiseActivation, 0.f, 0.f) /** * @brief MKLDNN ELU(Exponential Linear Unit) Activation. * f(x) = x (x >= 0) * f(x) = negative_slope * (exp(x) - 1) (x < 0) */ -DEFINE_MKLDNN_ELTWISE_ACTIVATION(elu, 0.f, 0.f) +DEFINE_MKLDNN_ELTWISE_ACTIVATION(elu, MKLDNNEltwiseActivation, 0.f, 0.f) + +mkldnn::algorithm MKLDNNEltwiseActivation::getAlgo(std::string type) const { + const std::map algoMap = { + {"relu", algorithm::eltwise_relu}, + {"tanh", algorithm::eltwise_tanh}, + {"elu", algorithm::eltwise_elu}}; + type.erase(0, 7); // remove mkldnn_ + algorithm algo = (algorithm)0; + mapGet(type, algoMap, &algo); + return algo; +} + +void MKLDNNEltwiseActivation::resetFwd(Argument& act) { + if (cnt_ == act.value->getElementCnt()) { + return; + } + MKLDNNActivation::resetFwd(act); + // note: alpha represents the NegativeSlope when used in relu. + float alpha = getAlpha(); + float beta = getBeta(); + algorithm algo = getAlgo(this->getName()); + auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training, + algo, + val_->getMemoryDesc(), + alpha, + beta); + fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, *engine_)); + // use inplace for forward but save input value before submit + inVal_ = val_; + copyInVal_ = nullptr; + if (act.grad && algo == algorithm::eltwise_tanh) { + // tanh need save src input for backward + inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc()); + copyInVal_ = std::make_shared(*val_, *inVal_); + CHECK(copyInVal_) << "should not be emptry"; + pipelineFwd_.push_back(*copyInVal_); + } + fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_)); + pipelineFwd_.push_back(*fwd_); + needResetBwd_ = true; +} + +void MKLDNNEltwiseActivation::resetBwd(Argument& act) { + if (!needResetBwd_) { + return; + } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; + needResetBwd_ = false; + algorithm algo = getAlgo(this->getName()); + float alpha = getBwdAlpha(); + float beta = getBeta(); + grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc()); + auto eng = CPUEngine::Instance().getEngine(); + auto bwdDesc = eltwise_bwd::desc( + algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta); + auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_); + CHECK(inVal_); + bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_)); + pipelineBwd_.clear(); + pipelineBwd_.push_back(*bwd_); +} + +/** + * @brief MKLDNN Softmax Activation + */ +DEFINE_MKLDNN_ACTIVATION(softmax, MKLDNNSoftmaxActivation) + +void MKLDNNSoftmaxActivation::resetFwd(Argument& act) { + if (cnt_ == act.value->getElementCnt()) { + return; + } + MKLDNNActivation::resetFwd(act); + int axis = 1; + auto fwdDesc = softmax_fwd::desc( + mkldnn::prop_kind::forward_scoring, val_->getMemoryDesc(), axis); + auto fwdPD = softmax_fwd::primitive_desc(fwdDesc, *engine_); + fwd_.reset(new softmax_fwd(fwdPD, *val_, *val_)); + pipelineFwd_.push_back(*fwd_); +} + +Error __must_check MKLDNNSoftmaxActivation::forward(Argument& act) { + resetFwd(act); + stream_->submit(pipelineFwd_); + real* v = act.value->getData(); + real threshold = exp(-64); +#pragma omp parallel for + for (size_t i = 0; i < act.value->getElementCnt(); ++i) { + v[i] = v[i] < threshold ? threshold : v[i]; + } + return Error(); +} + +Error __must_check MKLDNNSoftmaxActivation::backward(Argument& act) { + MatrixPtr outputV = act.value; + MatrixPtr outputG = act.grad; + Matrix::resizeOrCreate(sftMaxDot_, + outputG->getHeight(), + outputG->getWidth(), + /* trans */ false, + /* useGpu */ false); + Matrix::resizeOrCreate(sftMaxSum_, + outputG->getHeight(), + 1, + /* trans */ false, + /* useGpu */ false); + sftMaxDot_->dotMul(*outputG, *outputV); + sftMaxSum_->colMerge(*sftMaxDot_); + act.grad->softmaxDerivative(*act.value, *sftMaxSum_); + return Error(); +} ActivationFunction* MKLDNNActivation::create(const std::string& type) { return gMKLDNNActivationRegistrar.createByType(type); @@ -84,4 +216,34 @@ std::vector MKLDNNActivation::getAllRegisteredTypes() { return types; } +void MKLDNNActivation::resetFwd(Argument& act) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; + cnt_ = act.value->getElementCnt(); + pipelineFwd_.clear(); + stream_.reset(new MKLDNNStream()); + engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0)); + val_ = std::dynamic_pointer_cast(act.value); + if (val_ == nullptr) { + int bs = act.getBatchSize(); + int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1; + int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1; + int ic = cnt_ / bs / ih / iw; + CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw); + val_ = MKLDNNMatrix::create( + act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_); + CHECK(val_); + val_->downSpatial(); + } +} + +Error __must_check MKLDNNActivation::forward(Argument& act) { + resetFwd(act); + stream_->submit(pipelineFwd_); + return Error(); +} +Error __must_check MKLDNNActivation::backward(Argument& act) { + resetBwd(act); + stream_->submit(pipelineBwd_); + return Error(); +} } // namespace paddle diff --git a/paddle/gserver/activations/MKLDNNActivation.h b/paddle/gserver/activations/MKLDNNActivation.h index 40dd8c618aa2b70d410130e12efc54520218afea..dd16421fd6e93b49c30b1d3b601f95980afec57b 100644 --- a/paddle/gserver/activations/MKLDNNActivation.h +++ b/paddle/gserver/activations/MKLDNNActivation.h @@ -36,6 +36,7 @@ protected: // mkldnn matrix, primitive, stream and pipeline MKLDNNMatrixPtr val_; MKLDNNMatrixPtr grad_; + std::shared_ptr engine_; std::shared_ptr stream_; std::shared_ptr fwd_; std::shared_ptr bwd_; @@ -48,8 +49,18 @@ public: static ActivationFunction* create(const std::string& type); static std::vector getAllRegisteredTypes(); virtual const std::string& getName() const = 0; - virtual Error __must_check forward(Argument& act) = 0; - virtual Error __must_check backward(Argument& act) = 0; + /** + * reset the forward primitives + */ + virtual void resetFwd(Argument& act); + /** + * reset the backward primitives, + * can not merge this functions into resetFwd as the grad data + * would be changing before backward. + */ + virtual void resetBwd(Argument& act) {} + virtual Error __must_check forward(Argument& act); + virtual Error __must_check backward(Argument& act); }; /** @@ -59,6 +70,7 @@ public: class MKLDNNEltwiseActivation : public MKLDNNActivation { typedef mkldnn::eltwise_forward eltwise_fwd; typedef mkldnn::eltwise_backward eltwise_bwd; + typedef mkldnn::algorithm algorithm; protected: // save the forward primitive desc, which can be used backward @@ -70,9 +82,7 @@ protected: public: MKLDNNEltwiseActivation() {} - ~MKLDNNEltwiseActivation() {} - virtual const std::string& getName() const = 0; // in common, the alpha of forward and backward should be equal. @@ -80,105 +90,30 @@ public: virtual float getAlpha() const = 0; virtual float getBwdAlpha() const = 0; virtual float getBeta() const { return 0.f; } - virtual mkldnn::algorithm getAlgo(const std::string& type) const { - if (type == "mkldnn_relu") { - return mkldnn::algorithm::eltwise_relu; - } else if (type == "mkldnn_tanh") { - return mkldnn::algorithm::eltwise_tanh; - } else if (type == "mkldnn_elu") { - return mkldnn::algorithm::eltwise_elu; - } else { - LOG(FATAL) << "Unkown eltwise activation type: " << type; - } - return (mkldnn::algorithm)0; - } - - /** - * reshape and reset the forward primitives - */ - void resetFwd(Argument& act) { - if (cnt_ == act.value->getElementCnt()) { - return; - } - VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; - cnt_ = act.value->getElementCnt(); - stream_.reset(new MKLDNNStream()); - auto eng = CPUEngine::Instance().getEngine(); - - // get algo setting - mkldnn::algorithm algo = getAlgo(this->getName()); - // note: alpha represents the NegativeSlope when used in relu. - float alpha = getAlpha(); - float beta = getBeta(); - - pipelineFwd_.clear(); - val_ = std::dynamic_pointer_cast(act.value); - if (val_ == nullptr) { - int bs = act.getBatchSize(); - int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1; - int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1; - int ic = cnt_ / bs / ih / iw; - CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw); - val_ = MKLDNNMatrix::create( - act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, eng); - CHECK(val_); - } - auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training, - algo, - val_->getMemoryDesc(), - alpha, - beta); - fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, eng)); - // use inplace for forward but save input value before submit - inVal_ = val_; - copyInVal_ = nullptr; - if (act.grad && algo == mkldnn::algorithm::eltwise_tanh) { - // tanh need save src input for backward - inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc()); - copyInVal_ = std::make_shared(*val_, *inVal_); - CHECK(copyInVal_) << "should not be emptry"; - pipelineFwd_.push_back(*copyInVal_); - } - fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_)); - pipelineFwd_.push_back(*fwd_); - needResetBwd_ = true; - } + virtual algorithm getAlgo(std::string type) const; + void resetFwd(Argument& act) override; + void resetBwd(Argument& act) override; +}; - /** - * reset the backward primitives, can not merge into resetFwd as the grad data - * would be changing before backward. - */ - void resetBwd(Argument& act) { - if (!needResetBwd_) { - return; - } - VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; - needResetBwd_ = false; - mkldnn::algorithm algo = getAlgo(this->getName()); - float alpha = getBwdAlpha(); - float beta = getBeta(); - grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc()); - auto eng = CPUEngine::Instance().getEngine(); - auto bwdDesc = eltwise_bwd::desc( - algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta); - auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_); - CHECK(inVal_); - bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_)); - pipelineBwd_.clear(); - pipelineBwd_.push_back(*bwd_); - } +/** + * @brief Base class of MKLDNN softmax Activation, + * only have mkldnn forward, use cpu implement for backward. + */ +class MKLDNNSoftmaxActivation : public MKLDNNActivation { + typedef mkldnn::softmax_forward softmax_fwd; - Error __must_check forward(Argument& act) { - resetFwd(act); - stream_->submit(pipelineFwd_); - return Error(); - } +private: + // for backward + MatrixPtr sftMaxSum_; + MatrixPtr sftMaxDot_; - Error __must_check backward(Argument& act) { - resetBwd(act); - stream_->submit(pipelineBwd_); - return Error(); - } +public: + MKLDNNSoftmaxActivation() {} + ~MKLDNNSoftmaxActivation() {} + virtual const std::string& getName() const = 0; + void resetFwd(Argument& act) override; + Error __must_check forward(Argument& act) override; + Error __must_check backward(Argument& act) override; }; } // namespace paddle diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index 1bfbbde4246a10eaf86693a6a2f237f390966db3..857d07df3e3088be28943d9e2fe58017e9e57f4a 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -222,8 +222,8 @@ static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) { } void testActivation(std::string& actType, const testActDesc& pm) { - // TODO(TJ): mkldnn_softmax not implemented, paddle do not have elu activation - if (actType == "mkldnn_softmax" || actType == "mkldnn_elu") { + // TODO(TJ): remove me when paddle support elu activation + if (actType == "mkldnn_elu") { return; } const std::string compareTypes[] = {actType, actType.erase(0, 7)}; diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 0a6a0fd15c73330902552f7a9aa6339de24c1a18..75e8a989036f0b818687e1fec3e600bb90e86b22 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { return; } - AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>( - num_samples, infer_width, inference_data, label_data, accuracy_data); + AccuracyCudaKernel<<< + 1, PADDLE_CUDA_NUM_THREADS, 0, + reinterpret_cast( + ctx.device_context()) + .stream()>>>(num_samples, infer_width, inference_data, label_data, + accuracy_data); } }; diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index b11dc1472d153dd188a0b3553d6950774216a3fd..2e16201e74c153888594ebe6679fb0036734dad4 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null."); + "Input(Label) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) should be not null."); auto x = ctx.Input("X"); auto label = ctx.Input("Label"); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank must be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], - "The 1st dimension of Input(X) and Input(Label) must " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); - if (ctx.Attr("soft_label")) { + if (ctx.Attr("softLabel")) { PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], - "If Attr(soft_label) == true, The 2nd dimension of " - "Input(X) and Input(Label) must be equal."); + "If Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); } else { PADDLE_ENFORCE_EQ(label->dims()[1], 1, - "If Attr(soft_label) == false, The 2nd dimension of " - "Input(Label) must be 1."); + "If Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); } ctx.Output("Y")->Resize({x->dims()[0], 1}); @@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) must not be null."); + "Input(Label) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) must not be null."); + "Input(Y@GRAD) shoudl be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); auto x = ctx.Input("X"); auto label = ctx.Input("Label"); auto dy = ctx.Input(framework::GradVarName("Y")); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy->dims().size(), 2, + "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank must be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], - "The 1st dimension of Input(X) and Input(Label) must " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], - "The 1st dimension of Input(X) and Input(Y@Grad) must " + "The 1st dimension of Input(X) and Input(Y@Grad) should " "be equal."); PADDLE_ENFORCE_EQ(dy->dims()[1], 1, - "The 2nd dimension of Input(Y@Grad) must be 1."); - if (ctx.Attr("soft_label")) { + "The 2nd dimension of Input(Y@Grad) should be 1."); + if (ctx.Attr("softLabel")) { PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], - "If Attr(soft_label) == true, The 2nd dimension of " - "Input(X) and Input(Label) must be equal."); + "When Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); } else { PADDLE_ENFORCE_EQ(label->dims()[1], 1, - "If Attr(soft_label) == false, The 2nd dimension of " - "Input(Label) must be 1."); + "When Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); } auto dx = ctx.Output(framework::GradVarName("X")); @@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { CrossEntropyOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of CrossEntropyOp"); - AddInput("Label", "The second input of CrossEntropyOp"); - AddOutput("Y", "The output of CrossEntropyOp"); - AddAttr("soft_label", "Is soft label. Default zero.") + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape N x D, " + "where N is the batch size and D is the number of classes. " + "This input is a probability computed by the previous operator, " + "which is almost always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor, default Tensor), the ground truth which is " + "a 2-D tensor. " + "When softLabel is set to false, `Label` is a Tensor with shape " + "[N x 1]. " + "When softLabel is set to true, `Label` is a Tensor " + "with shape [N x K]."); + AddOutput("Y", + "(Tensor, default Tensor), a 2-D tensor " + "with shape [N x 1]. The cross entropy loss."); + AddAttr( + "softLabel", + "(bool, default false), a flag to indicate whether to interpretate " + "the given labels as soft labels.") .SetDefault(false); - AddComment(R"DOC( CrossEntropy Operator. It supports both standard cross-entropy and soft-label cross-entropy loss computation. 1) One-hot cross-entropy: - soft_label = False, Label[i, 0] indicates the class index for sample i: + softLabel = false, Label[i, 0] indicates the class index for sample i: Y[i] = -log(X[i, Label[i]]) 2) Soft-label cross-entropy: - soft_label = True, Label[i, j] indicates the soft label of class j + softLabel = true, Label[i, j] indicates the soft label of class j for sample i: Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 1d6361a81472a49729958120c52060b1dff803f2..18e44d77c9f62b296dc57952e546f844670c7d57 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -tolerable_value(log(X[i * D + label[i]])); + Y[i] = -TolerableValue()(log(X[i * D + label[i]])); } } +template +__device__ __forceinline__ T sum_single_warp(T val) { + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + return val; +} + template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, - const int N, const int D) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - T sum = static_cast(0); - for (int j = 0; j < D; j++) { - sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); - } - Y[i] = -sum; + const int class_num) { + int tid = threadIdx.x; + extern __shared__ T d_sum[]; + d_sum[tid] = 0; + + int cur_idx = tid; + int next_idx = blockIdx.x * class_num + tid; + while (cur_idx < class_num) { + d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; + next_idx += blockDim.x; + cur_idx += blockDim.x; + } + __syncthreads(); + + for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { + if (tid < stride) d_sum[tid] += d_sum[tid + stride]; + __syncthreads(); } + + T val = d_sum[tid]; + val = sum_single_warp(val); + if (tid == 0) Y[blockIdx.x] = -val; } -// TODO(qingqing): make zero setting an common function. +// TODO(qingqing): make zero setting a common function. template -__global__ void zero(T* X, const int N) { +__global__ void Zero(T* X, const int N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { X[i] = 0.0; @@ -71,13 +94,10 @@ template __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const T* label, const int N, const int D) { - // TOOD(qingqing): optimize for this kernel - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - for (int j = 0; j < D; ++j) { - int idx = i * D + j; - dX[idx] = -label[idx] * dY[i] / X[idx]; - } + int ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < N * D) { + int row_ids = ids / D; + dX[ids] = -label[ids] * dY[row_ids] / X[ids]; } } @@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); + "This kernel only runs on GPU device."); - auto x = ctx.Input("X"); - auto y = ctx.Output("Y"); - auto label = ctx.Input("Label"); + const Tensor* x = ctx.Input("X"); + const Tensor* label = ctx.Input("Label"); + Tensor* y = ctx.Output("Y"); - auto* x_data = x->data(); - y->mutable_data(ctx.GetPlace()); - auto* y_data = y->data(); + const T* x_data = x->data(); + T* y_data = y->mutable_data(ctx.GetPlace()); - int n = x->dims()[0]; - int d = x->dims()[1]; - int block = 512; - int grid = (n + block - 1) / block; - // TODO(qingqing) launch kernel on specified stream - // base on ExecutionContext. - if (ctx.Attr("soft_label")) { + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; + + if (ctx.Attr("softLabel")) { auto* label_data = ctx.Input("Label")->data(); - SoftCrossEntropyKernel<<>>(y_data, x_data, label_data, n, - d); + int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); + + SoftCrossEntropyKernel< + T><<( + ctx.device_context()) + .stream()>>>(y_data, x_data, label_data, class_num); } else { auto* label_data = ctx.Input("Label")->data(); - CrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); + int block = 512; + int grid = (batch_size + block - 1) / block; + CrossEntropyKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(y_data, x_data, label_data, + batch_size, class_num); } } }; @@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); + "This kernel only runs on GPU device."); + + const Tensor* x = ctx.Input("X"); + const Tensor* label = ctx.Input("Label"); + Tensor* dx = ctx.Output(framework::GradVarName("X")); - auto x = ctx.Input("X"); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("Label"); + const T* dy_data = + ctx.Input(framework::GradVarName("Y"))->data(); + T* dx_data = dx->mutable_data(ctx.GetPlace()); + const T* x_data = x->data(); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); - auto* x_data = x->data(); + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; - int n = x->dims()[0]; - int d = x->dims()[1]; int block = 512; - int grid = (n * d + block - 1) / block; - zero<<>>(dx_data, n * d); - grid = (n + block - 1) / block; - // TODO(qingqing): launch kernel on specified stream - // base on ExecutionContext. - if (ctx.Attr("soft_label")) { + int grid = (batch_size * class_num + block - 1) / block; + + if (ctx.Attr("softLabel")) { auto* label_data = label->data(); - SoftCrossEntropyGradientKernel<<>>( - dx_data, dy_data, x_data, label_data, n, d); + SoftCrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(dx_data, dy_data, x_data, label_data, + batch_size, class_num); } else { + Zero<<( + ctx.device_context()) + .stream()>>>(dx_data, batch_size * class_num); + auto* label_data = label->data(); - CrossEntropyGradientKernel<<>>(dx_data, dy_data, x_data, - label_data, n, d); + grid = (batch_size + block - 1) / block; + CrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(dx_data, dy_data, x_data, label_data, + batch_size, class_num); } } }; diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 69caba5ff31f60df2c24cef0e6331f058f6ba8d6..255b2e9f5ea7566cca7fd3914e38da804b7c7006 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/platform/hostdevice.h" @@ -20,53 +21,51 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; template -HOSTDEVICE T tolerable_value(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - if (x == INFINITY) { - return kApproInf; +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; } - if (x == -INFINITY) { - return -kApproInf; - } - return x; -} +}; template class CrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto x = ctx.Input("X"); - auto y = ctx.Output("Y"); - - auto* x_data = x->data(); - y->mutable_data(ctx.GetPlace()); - auto* y_data = y->data(); - - int batch_size = x->dims()[0]; - int class_num = x->dims()[1]; - - if (ctx.Attr("soft_label")) { - auto* label_data = ctx.Input("Label")->data(); - int index = 0; - for (int i = 0; i < batch_size; ++i) { - T sum = static_cast(0); - for (int j = 0; j < class_num; ++j) { - sum += label_data[index] * tolerable_value(std::log(x_data[index])); - y_data[i] = -sum; - index++; - } - } + "This kernel only runs on CPU."); + const Tensor* x = ctx.Input("X"); + const Tensor* labels = ctx.Input("Label"); + Tensor* y = ctx.Output("Y"); + T* y_data = y->mutable_data(ctx.GetPlace()); + + const int batch_size = x->dims()[0]; + if (ctx.Attr("softLabel")) { + auto prob = EigenMatrix::From(*x); + auto lbl_mat = EigenMatrix::From(*labels); + auto loss = EigenMatrix::From(*y); + + loss.device(ctx.GetEigenDevice()) = + -((lbl_mat * prob.log().unaryExpr(TolerableValue())) + .sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(batch_size, 1))); } else { - auto* label_data = ctx.Input("Label")->data(); + const int class_num = x->dims()[1]; + const T* x_data = x->data(); + + const int* label_data = labels->data(); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; - y_data[i] = -tolerable_value(std::log(x_data[index])); + y_data[i] = -TolerableValue()(std::log(x_data[index])); } } } @@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto x = ctx.Input("X"); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("Label"); + "This kernel only runs on CPU."); + const Tensor* x = ctx.Input("X"); + const Tensor* dy = ctx.Input(framework::GradVarName("Y")); + const Tensor* label = ctx.Input("Label"); + Tensor* dx = ctx.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); - auto* x_data = x->data(); - - int batch_size = x->dims()[0]; int class_num = x->dims()[1]; - - // TODO(qingqing): make zero setting an common function. - if (ctx.Attr("soft_label")) { - auto* label_data = ctx.Input("Label")->data(); - int index = 0; - for (int i = 0; i < batch_size; ++i) { - for (int j = 0; j < class_num; ++j) { - dx_data[index] = -label_data[index] * dy_data[i] / x_data[index]; - index++; - } - } + if (ctx.Attr("softLabel")) { + auto x_mat = EigenMatrix::From(*x); + auto dy_mat = EigenMatrix::From(*dy); + auto lbl_mat = EigenMatrix::From(*label); + auto dx_mat = EigenMatrix::From(*dx); + + dx_mat.device(ctx.GetEigenDevice()) = + -(lbl_mat * dy_mat.broadcast(Eigen::DSizes(1, class_num)) / + x_mat); } else { - auto* label_data = label->data(); + int batch_size = x->dims()[0]; + const T* dy_data = dy->data(); + const T* x_data = x->data(); + const int* label_data = label->data(); + + // TODO(qingqing): make zero setting a common function. memset(dx_data, 0, sizeof(T) * batch_size * class_num); + for (int i = 0; i < batch_size; ++i) { PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); int index = i * class_num + label_data[i]; diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 708344046760691aa2da562eb1ee3d8b130c5f18..62f63b4f3c876e084e2468001e8bcb9310d16a82 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTable<<>>(output, table, ids, N, K, D); + LookupTable<<< + grids, threads, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(output, table, ids, N, K, D); } }; @@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTableGrad<<>>(d_table, d_output, ids, N, - K, D); + LookupTableGrad<<< + grids, threads, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(d_table, d_output, ids, N, K, D); } }; diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 6e77b86b5698a263b850a973cd1b8644a0aa2201..7b50444d16dc57fd14b918d1159e3e21ecd1f1c4 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -18,7 +18,6 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; class MultiplexOp : public framework::OperatorWithKernel { public: @@ -26,24 +25,31 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), + "Input(Ids) shouldn't be null."); PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) should not be null"); + "MultiInput(X) shouldn't be empty."); PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), "Output(Out) shouldn't be null."); + auto ids_dim = ctx.Input("Ids")->dims(); + PADDLE_ENFORCE( + ids_dim.size() == 2 && ids_dim[1] == 1, + "The index tensor must be a vector with size batchSize x 1."); + auto ins = ctx.MultiInput("X"); - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); auto num_ins = ins.size(); - PADDLE_ENFORCE(num_ins > 2, - "multiplex operator should have more than 2 inputs."); - PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1, - "The first input must be a index vector."); - auto in_dim = ins[1]->dims(); - - for (size_t i = 2; i < num_ins; i++) { + PADDLE_ENFORCE(num_ins > 1, + "multiplex operator should have more than " + "one candidate input tensors."); + + auto in_dim = ins[0]->dims(); + PADDLE_ENFORCE(in_dim.size() >= 2, + "The rank of candidate tensors must be not less than 2."); + for (size_t i = 1; i < num_ins; i++) { auto dim = ins[i]->dims(); - PADDLE_ENFORCE( - in_dim == dim, - "All the input tensors except the first one must have the same size"); + PADDLE_ENFORCE(in_dim == dim, + "All the candidate tensors must have the same size."); } out->Resize(in_dim); } @@ -54,25 +60,25 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { MultiplexOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input tensors of multiplex operator.").AsDuplicable(); + AddInput("Ids", "The index tensor of multiplex operator."); + AddInput("X", "The candidate tensors of multiplex operator.") + .AsDuplicable(); AddOutput("Out", "The output tensor of multiplex operator."); AddComment(R"DOC(Multiplex operator -Multiplex multiple tensors according to the index provided by the first -input tensor. +Multiplex multiple tensors according to the index provided by the index tensor. -ins[0]: the index tensor. -ins[1:N]: the candidate output tensors. +Ids: the index tensor. +X[0 : N - 1]: the candidate tensors for output (N >= 2). For each index i from 0 to batchSize - 1, the output is the i-th row of the -the (index[i] + 1)-th tensor. +the (Ids[i])-th tensor. For i-th row of the output tensor: -y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1) +y[i] = x_{k}[i] where y is the output tensor. `x_{k}` is the k-th input tensor -and `k = x{0}[i] + 1`. - +and `k = Ids[i]`. )DOC"); } }; @@ -84,15 +90,15 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) should not be null"); + "Input(X) should not be null."); PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), - "Output(X@Grad) should not be null"); + "Output(X@Grad) should not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + "Input(Out@GRAD) should not be null."); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); auto ins = ctx.MultiInput("X"); - // don't compute gradient for index (ins[0]) - for (size_t i = 1; i < ins.size(); i++) { + // No need to compute gradient for Input(Ids) + for (size_t i = 0; i < ins.size(); i++) { if (d_ins[i]) { d_ins[i]->Resize(ins[i]->dims()); } diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 4736f15bd594178168e3bcf799142d0fc18bff13..70e46815fc9148a2530d437d20c14f5d40baa1a4 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -18,27 +18,30 @@ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class MultiplexGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { - auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); - + auto ins = ctx.MultiInput("X"); + auto* ids = ctx.Input("Ids"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->numel() / rows; // copy index to cpu - framework::Tensor index_t_cpu; - index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); - auto* index = index_t_cpu.data(); + Tensor index_t_cpu; + index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); + auto* index = index_t_cpu.data(); auto stream = reinterpret_cast( ctx.device_context()) .stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - int k = (int)index[i] + 1; + int32_t k = index[i]; + PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative."); PADDLE_ENFORCE_LT(k, ins.size(), "index exceeds the number of candidate tensors."); memory::Copy(place, out->data() + i * cols, place, @@ -51,11 +54,11 @@ template class MultiplexGradGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { - auto* d_out = ctx.Input(framework::GradVarName("Out")); - auto ins = ctx.MultiInput("X"); - auto d_ins = - ctx.MultiOutput(framework::GradVarName("X")); - for (size_t i = 1; i < d_ins.size(); i++) { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto* ids = ctx.Input("Ids"); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + for (size_t i = 0; i < d_ins.size(); i++) { if (d_ins[i]) { d_ins[i]->mutable_data(ctx.GetPlace()); auto t = framework::EigenVector::Flatten(*d_ins[i]); @@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel { } } - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->numel() / rows; // copy index to cpu - framework::Tensor index_t_cpu; - index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); - auto* index = index_t_cpu.data(); + Tensor index_t_cpu; + index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); + auto* index = index_t_cpu.data(); auto stream = reinterpret_cast( ctx.device_context()) .stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - int k = (int)index[i] + 1; + size_t k = static_cast(index[i]); if (d_ins[k]) { memory::Copy(place, d_ins[k]->data() + i * cols, place, d_out->data() + i * cols, cols * sizeof(T), stream); diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h index 98466426bd90bc30a22ecf74e6739e2d4ad1d21d..637c63a34af394f5f54997c46c00a9ff00577476 100644 --- a/paddle/operators/multiplex_op.h +++ b/paddle/operators/multiplex_op.h @@ -27,16 +27,18 @@ class MultiplexCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + auto ids = ctx.Input("Ids"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; - auto* index = ins[0]->data(); + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->numel() / rows; + auto index = ids->data(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - int k = (int)index[i] + 1; + int32_t k = index[i]; + PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative."); PADDLE_ENFORCE_LT(static_cast(k), ins.size(), "index exceeds the number of candidate tensors."); memory::Copy(place, out->data() + i * cols, place, @@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* ids = ctx.Input("Ids"); auto ins = ctx.MultiInput("X"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (size_t i = 1; i < d_ins.size(); i++) { + for (size_t i = 0; i < d_ins.size(); i++) { if (d_ins[i]) { d_ins[i]->mutable_data(ctx.GetPlace()); auto t = framework::EigenVector::Flatten(*d_ins[i]); @@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel { } } - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; - auto* index = ins[0]->data(); + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->numel() / rows; + auto* index = ids->data(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - int k = (int)index[i] + 1; + size_t k = static_cast(index[i]); if (d_ins[k]) { memory::Copy(place, d_ins[k]->data() + i * cols, place, d_out->data() + i * cols, cols * sizeof(T)); diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu index afe4d149c53819c45e20353bc9d16393f3f61e0f..53fe505b77bfac8a33803f082f8e935d3ed403b6 100644 --- a/paddle/operators/top_k_op.cu +++ b/paddle/operators/top_k_op.cu @@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel { // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. - // TODO(typhoonzero): launch kernel on specified stream. // TODO(typhoonzero): refine this kernel. dim3 threads(256, 1); dim3 grid(input_height, 1); - KeMatrixTopK<<>>( - output_data, output->dims()[1], indices_data, input_data, input_width, - input_width, int(k)); + KeMatrixTopK<<< + grid, threads, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(output_data, output->dims()[1], + indices_data, input_data, + input_width, input_width, int(k)); } }; diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index caa78acd98ea4b35fc69643689cfce23026275e0..895e8d6a63d1fad0ee7a6f5647402435d418b2f1 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "ParameterOptimizer.h" +#include "ParameterUpdateFunctions.h" #include "Regularizer.h" namespace paddle { @@ -37,6 +38,15 @@ public: real torch_learningRate = optConfig_.learning_method() == "torch_momentum" ? 1.0 - paraConfig.momentum() : 1.0; +#ifdef PADDLE_USE_MKLDNN + sgdUpdate(learningRate_ * paraConfig.learning_rate() * + (firstTime_ ? 1.0 : torch_learningRate), + paraConfig.momentum(), + applyDecay_ ? paraConfig.decay_rate() : 0, + vecs[PARAMETER_VALUE].get(), + vecs[PARAMETER_GRADIENT].get(), + vecs[PARAMETER_MOMENTUM].get()); +#else vecs[PARAMETER_VALUE]->sgdUpdate( *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM], @@ -44,6 +54,7 @@ public: (firstTime_ ? 1.0 : torch_learningRate), paraConfig.momentum(), applyDecay_ ? paraConfig.decay_rate() : 0); +#endif } virtual void finishBatch() { firstTime_ = false; } }; diff --git a/paddle/parameter/ParameterUpdateFunctions.cpp b/paddle/parameter/ParameterUpdateFunctions.cpp index c8af7105c78dcbf9f625a348b7f38efcf278469e..8b3be062b654a52e667626199be8c8bb4a2a96d7 100644 --- a/paddle/parameter/ParameterUpdateFunctions.cpp +++ b/paddle/parameter/ParameterUpdateFunctions.cpp @@ -30,6 +30,9 @@ void sgdUpdateCpu(real learningRate, const real* grad, real* momentumVec) { decayRate *= learningRate; +#ifdef PADDLE_USE_MKLDNN +#pragma omp parallel for +#endif for (size_t i = 0; i < size; ++i) { momentumVec[i] = momentum * momentumVec[i] - learningRate * grad[i] - decayRate * value[i]; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 0f57b81966647ca5c6f5cd2e5518d2d34942a549..098a51ab8791290d3e0ffa2c3703c724dd2387b9 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1566,7 +1566,7 @@ class LayerBase(object): self.config = g_config.model_config.layers.add() assert isinstance(self.config, LayerConfig) use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0))) - mkldnn_acts = ['relu', 'tanh'] + mkldnn_acts = ['relu', 'tanh', 'softmax'] if use_mkldnn and active_type in mkldnn_acts: active_type = "mkldnn_" + active_type self.config.name = name diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index f10db783225c07be9ffde25267fdfe096e97ecac..1de514dff487158e0823fd628d9b3b50f36fdd9b 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -4,22 +4,24 @@ from op_test import OpTest class TestCrossEntropyOp1(OpTest): - """Test standard cross-entropy, with index representation of labels. + """Test cross-entropy with discrete one-hot labels. """ def setUp(self): self.op_type = "cross_entropy" batch_size = 30 class_num = 10 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") cross_entropy = np.asmatrix( [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], dtype="float32") + self.inputs = {"X": X, "Label": label} self.outputs = {"Y": cross_entropy} - self.attrs = {'soft_label': False} + self.attrs = {"softLabel": False} def test_check_output(self): self.check_output() @@ -29,13 +31,14 @@ class TestCrossEntropyOp1(OpTest): class TestCrossEntropyOp2(OpTest): - """Test soft-label cross-entropy, with vecterized soft labels. + """Test cross-entropy with vectorized soft labels. """ def setUp(self): self.op_type = "cross_entropy" - batch_size = 10 - class_num = 5 + batch_size = 5 + class_num = 37 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label = np.random.uniform(0.1, 1.0, @@ -43,46 +46,49 @@ class TestCrossEntropyOp2(OpTest): label /= label.sum(axis=1, keepdims=True) cross_entropy = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") - self.inputs = {'X': X, 'Label': label} - self.outputs = {'Y': cross_entropy} - self.attrs = {'soft_label': True} + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"softLabel": True} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y", max_relative_error=0.05) class TestCrossEntropyOp3(OpTest): - """Test one-hot cross-entropy, with vecterized one-hot representation of - labels. + """Test cross-entropy with vectorized one-hot representation of labels. """ def setUp(self): self.op_type = "cross_entropy" - batch_size = 30 - class_num = 10 + batch_size = 5 + class_num = 17 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label_index = np.random.randint( 0, class_num, (batch_size), dtype="int32") label = np.zeros(X.shape) label[np.arange(batch_size), label_index] = 1 + cross_entropy = np.asmatrix( [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], dtype="float32") cross_entropy2 = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") - self.inputs = {'X': X, 'Label': label} - self.outputs = {'Y': cross_entropy} - self.attrs = {'soft_label': True} + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"softLabel": True} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y", max_relative_error=0.05) if __name__ == "__main__": diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py index f2b3881cde24c7fb96c3d7f9411352bc62d55077..5937eb5aa4621556c9b8d59ea83a39d9738c7925 100644 --- a/python/paddle/v2/framework/tests/test_multiplex_op.py +++ b/python/paddle/v2/framework/tests/test_multiplex_op.py @@ -6,20 +6,22 @@ from op_test import OpTest class TestMultiplexOp(OpTest): def setUp(self): self.op_type = "multiplex" - rows = 3 - index = np.array([3, 1, 0]) + rows = 4 + index = np.arange(0, rows).astype('int32') + np.random.shuffle(index) + index = np.reshape(index, (rows, 1)) ins1 = np.random.random((rows, 10)).astype("float32") ins2 = np.random.random((rows, 10)).astype("float32") ins3 = np.random.random((rows, 10)).astype("float32") ins4 = np.random.random((rows, 10)).astype("float32") self.inputs = { - 'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3), - ('x4', ins4)] + 'Ids': index, + 'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)] } # multiplex output output = np.zeros_like(ins1) for i in range(0, rows): - k = index[i] + 1 + k = index[i][0] output[i] = self.inputs['X'][k][1][i] self.outputs = {'Out': output}