diff --git a/benchmark/paddle/image/provider.py b/benchmark/paddle/image/provider.py index 1ac47212b5a75667e8e9d4465b33f575516e2836..4703944c8722552d56ba80a8e0663de5fb4df53d 100644 --- a/benchmark/paddle/image/provider.py +++ b/benchmark/paddle/image/provider.py @@ -22,5 +22,5 @@ def initHook(settings, height, width, color, num_class, **kwargs): def process(settings, file_list): for i in xrange(1024): img = np.random.rand(1, settings.data_size).reshape(-1, 1).flatten() - lab = random.randint(0, settings.num_class) + lab = random.randint(0, settings.num_class - 1) yield img.astype('float32'), int(lab) diff --git a/benchmark/paddle/image/run_mkldnn.sh b/benchmark/paddle/image/run_mkldnn.sh new file mode 100755 index 0000000000000000000000000000000000000000..5b0a0373448e5b81ff0718db3465a4694690ec37 --- /dev/null +++ b/benchmark/paddle/image/run_mkldnn.sh @@ -0,0 +1,51 @@ +set -e + +unset OMP_NUM_THREADS MKL_NUM_THREADS +export OMP_DYNAMIC="FALSE" +export KMP_AFFINITY="granularity=fine,compact,0,0" + +function train() { + topology=$1 + bs=$2 + use_mkldnn=$3 + if [ $3 == "True" ]; then + use_mkldnn=$3 + thread=1 + log="logs/${topology}-mkldnn-${bs}.log" + elif [ $3 == "False" ]; then + use_mkldnn=$3 + thread=`nproc` + log="logs/${topology}-${thread}mklml-${bs}.log" + else + echo "Wrong input $3, use True or False." + fi + args="batch_size=${bs}" + config="${topology}.py" + paddle train --job=time \ + --config=$config \ + --use_mkldnn=$use_mkldnn \ + --use_gpu=False \ + --trainer_count=$thread \ + --log_period=10 \ + --test_period=100 \ + --config_args=$args \ + 2>&1 | tee ${log} +} + +if [ ! -d "train.list" ]; then + echo " " > train.list +fi +if [ ! -d "logs" ]; then + mkdir logs +fi + +#========= mkldnn =========# +# vgg +train vgg 64 True +train vgg 128 True +train vgg 256 True + +#========== mklml ===========# +train vgg 64 False +train vgg 128 False +train vgg 256 False diff --git a/benchmark/paddle/image/vgg.py b/benchmark/paddle/image/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..b8429975f5c83df6996e71478fe276b246e8b77b --- /dev/null +++ b/benchmark/paddle/image/vgg.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +from paddle.trainer_config_helpers import * + +height = 224 +width = 224 +num_class = 1000 +batch_size = get_config_arg('batch_size', int, 64) +layer_num = get_config_arg('layer_num', int, 19) + +args = {'height': height, 'width': width, 'color': True, 'num_class': num_class} +define_py_data_sources2( + "train.list", None, module="provider", obj="process", args=args) + +settings( + batch_size=batch_size, + learning_rate=0.01 / batch_size, + learning_method=MomentumOptimizer(0.9), + regularization=L2Regularization(0.0005 * batch_size)) + +img = data_layer(name='image', size=height * width * 3) + + +def vgg_network(vgg_num=3): + tmp = img_conv_group( + input=img, + num_channels=3, + conv_padding=1, + conv_num_filter=[64, 64], + conv_filter_size=3, + conv_act=ReluActivation(), + pool_size=2, + pool_stride=2, + pool_type=MaxPooling()) + + tmp = img_conv_group( + input=tmp, + conv_num_filter=[128, 128], + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + + channels = [] + for i in range(vgg_num): + channels.append(256) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + channels = [] + for i in range(vgg_num): + channels.append(512) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + + tmp = fc_layer( + input=tmp, + size=4096, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + + tmp = fc_layer( + input=tmp, + size=4096, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + + return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation()) + + +if layer_num == 16: + vgg = vgg_network(3) +elif layer_num == 19: + vgg = vgg_network(4) +else: + print("Wrong layer number.") + +lab = data_layer('label', num_class) +loss = cross_entropy(input=vgg, label=lab) +outputs(loss) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 0bbf92293168d4e3af2c1ed0e82b75e6a8d6c0cd..ff9868fc4e0d970b11e4763d2e0c8581f4f85907 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -253,7 +253,7 @@ function(nv_library TARGET_NAME) foreach(source_file ${nv_library_SRCS}) string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) - list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND nv_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) endif() endforeach() add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS}) diff --git a/cmake/util.cmake b/cmake/util.cmake index ac911052eb970c5a3e485e3178dd788b1517ca30..d1aee3e170a2d143ac06b438725e907e96f041c8 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -97,6 +97,10 @@ function(link_paddle_exe TARGET_NAME) target_link_libraries(${TARGET_NAME} log) endif(ANDROID) + if(WITH_MKLDNN AND WITH_MKLML AND MKLDNN_IOMP_DIR) + target_link_libraries(${TARGET_NAME} "-L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed") + endif() + add_dependencies(${TARGET_NAME} ${external_project_dependencies}) endfunction() diff --git a/doc/design/refactor/distributed_architecture.md b/doc/design/refactor/distributed_architecture.md new file mode 100644 index 0000000000000000000000000000000000000000..ac7e98ccf1aadbb973a4801fde842375cf63448c --- /dev/null +++ b/doc/design/refactor/distributed_architecture.md @@ -0,0 +1,222 @@ +# Design Doc: Distributed Training Architecture + +## Abstract + +PaddlePaddle v0.10.0 uses the "trainer-parameter server" +architecture. We run multiple replicated instances of trainers (runs +the same code written by the user) and parameter servers for +distributed training. This architecture served us well, but has some +limitations: + +1. Need to write special code to handle tasks which should only be run + by a single trainer. E.g., initializing model and saving model. + +2. Model parallelism is hard: need to write if-else branches conditioned + on the trainer ID to partition model onto each trainer, and manually + write the inter-model-shard communication code. + +3. The user can not directly specify the parameter update rule: need + to modify the parameter server C++ code and compile a new + binary. This adds complication for researchers: A lot of extra + effort is required. Besides, the training job submission program + may not allow running arbitrary binaries. + +This design doc discusses PaddlePaddle's new distributed training +architecture that addresses the above limitations. + +## Analysis + +We will assume the user writes the trainer program by Python, the same +analysis holds if the trainer program is written in C++. + +### Limitation 1 + +If we look at the Python code that the user writes, there are two +kinds of functionalities: + +- The training logic such as load / save model and print log. +- The neural network definition such as the definition of the data + layer, the fully connected layer, the cost function and the + optimizer. + +When we training with PaddlePaddle v0.10.0 distributedly, multiple +replicated Python instances are running on different nodes: both the +training logic and the neural network computation is replicated. + +The tasks that should only run once all belong to the training logic, +if we only replicate the neural network computation, but do **not** +replicate the training logic, the limitation could be solved. + +### Limitation 2 + +Model parallelism means running a single model on multiple nodes by +partitioning the model onto different nodes and managing the +inter-model-shard communications. + +PaddlePaddle should be able to modify the nerual network computation +definition to support model parallelism automatically. However, the +computation is only specified in Python code, and PaddlePaddle can not +modify Python code. + +Just like compiler uses a intermediate representation (IR) so that +programmer does not need to manually optimize their code in most of +the cases - the compiler will optimize the IR: + + + +We can have our own IR too: PaddlePaddle can support model parallel by +converting the IR so the user no longer need to manually do it in +Python: + + + +The IR for PaddlePaddle after refactor is called `Block`, it specifies +the computation dependency graph and the variables used in the +computation. + +### Limitation 3 + +The user can not directly specify the parameter update rule for the +parameter server because the parameter server does not use the same +computation definition as the trainer. Instead, the update rule is +baked in the parameter server. The user can not specify the update +rule in the same way of specifying the trainer computation. + +This could be fixed by making the parameter server run the same +computation definition as the trainer. For a detailed explanation, +please +see +[Design Doc: Operation Graph Based Parameter Server](./dist_train.md) + +## Distributed Training Architecture + +The new distributed training architecture can address the above +limitations. Below is the illustration: + + + +The architecture includes major components: *PaddlePaddle Python*, +*PaddlePaddle converter* and *PaddlePaddle runtime*: + +### PaddlePaddle Python + +PaddlePaddle Python is the Python library that user's Python trainer +invoke to build the neural network topology, start training, etc. + +```Python +paddle.init() +input = paddle.op.recordIO("/home/data/mnist.recordio") # file stored on the cluster +img, label = input[0], input[1] +hidden = paddle.layer.fc(input=img, size=200, act=paddle.activation.Tanh()) +prediction = paddle.layer.fc(input=img, size=10, act=paddle.activation.Softmax()) +cost = paddle.layer.classification_cost(input=prediction, label=label) +optimizer = paddle.optimizer.SGD(cost, learning_rate=0.01) +session = paddle.session.NewRemote(num_trainer=3, num_ps=2, GPU_per_trainer=1) +for i in range(1000): + _, cost_val = session.eval(targets=[cost, optimizer]) + print cost_val +``` + +The code above is a typical Python trainer code, the neural network +topology is built using helper functions such as +`paddle.layer.fc`. The training is done by calling `session.eval` +iteratively. + +#### session.eval + +As shown in the graph, `session.eval` sends the IR and the evaluation +inputs/targets to the PaddlePaddle cluster for evaluation. The +targets can be any variable in the computation graph. When the target +is the `optimizer` variable, the neural network will be optimized +once. When the target is the `cost` variable, `session.eval` returns +the cost value. + +The Python `session` is a wrapper of the C++ `Session` class. For more +information about `Session`, please +see [Design Doc: Session](./session.md). + +### PaddlePaddle Converter + +PaddlePaddle converter automatically converts the IR in the request +(IR and evaluation inputs/targets) from PaddlePaddle Python to new +partitioned IRs and dispatch the new IRs and evaluation inputs/targets +to different PaddlePaddle runtimes. Below are the steps: + +1. Add `feed` OP that feeds the eval inputs, and `fetch` OP that + fetches the eval targets to the IR. + +1. Extract a new computation (sub)graph with `feed` and `fetch` OP as + the boundary. The runtime does not need to run the OP that is not + dependent by the `fetch` OP. + +1. Optimizes the computation graph. + +1. Place the OPs in the graph onto different devices on different + PaddlePaddle runtime according to a placement algorithm and device + constraint specified by the user. + +1. Partition the graph according to runtime boundaries and add `send` / + `recv` OP pair on the runtime boundaries. + +1. Dispatch the partitioned graph to different PaddlePaddle runtimes. + +1. PaddlePaddle runtimes with the `fetch` OP reports evaluation + results back to the converter, the convert reports the evaluation + results back to the PaddlePaddle Python. + +The output IRs will be cached to optimize the conversion latency. + + +#### Placement Algorithm + +Our first implementation will only support "trainer-parameter server" +placement: the parameters, initializers, and optimizers are placed on +the PaddlePaddle runtimes with the parameter server role. And +everything else will be placed on the PaddlePaddle runtimes with the +trainer role. This has the same functionality of our +"trainer-parameter server" architecture of PaddlePaddle v0.10.0, but +is more general and flexible. + +In the future, we will implement the general placement algorithm, +which makes placements according to the input IR, and a model of +device computation time and device communication time. Model +parallelism requires the general placement algorithm. + + +### PaddlePaddle Runtime + +The PaddlePaddle runtime owns multiple devices (e.g., CPUs, GPUs) and +runs the IR. The runtime does not need to do OP placement since it's +already done by the converter. + + +### Local Training Architecture + +The local training architecture will be the same as the distributed +training architecture, the differences are everything runs locally, +and there is just one PaddlePaddle runtime: + + + + +### Training Data + +In PaddlePaddle v0.10.0, training data is typically read +with [data reader](../reader/README.md) from Python. This approach is +no longer efficient when training distributedly since the Python +process no longer runs on the same node with the trainer processes, +the Python reader will need to read from the distributed filesystem +(assuming it has the access) and send to the trainers, doubling the +network traffic. + +When doing distributed training, the user can still use Python data +reader: the training data are sent with `session.eval`. However should +be used for debugging purpose only. The users are encouraged to use +the read data OPs. + + +## References: + +[1] [TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf) + +[2] [TensorFlow: A System for Large-Scale Machine Learning](https://www.usenix.org/system/files/conference/osdi16/osdi16-abadi.pdf) diff --git a/doc/design/ops/dist_train.md b/doc/design/refactor/parameter_server.md similarity index 100% rename from doc/design/ops/dist_train.md rename to doc/design/refactor/parameter_server.md diff --git a/doc/design/refactor/src/compiler.graffle b/doc/design/refactor/src/compiler.graffle new file mode 100644 index 0000000000000000000000000000000000000000..8cc678fea3c820103e7ce81f7a5d625d6c1d92de Binary files /dev/null and b/doc/design/refactor/src/compiler.graffle differ diff --git a/doc/design/refactor/src/compiler.png b/doc/design/refactor/src/compiler.png new file mode 100644 index 0000000000000000000000000000000000000000..65d34f841afce9756def07dd8ecb9ca44e658bfe Binary files /dev/null and b/doc/design/refactor/src/compiler.png differ diff --git a/doc/design/ops/src/dist-graph.graffle b/doc/design/refactor/src/dist-graph.graffle similarity index 100% rename from doc/design/ops/src/dist-graph.graffle rename to doc/design/refactor/src/dist-graph.graffle diff --git a/doc/design/ops/src/dist-graph.png b/doc/design/refactor/src/dist-graph.png similarity index 100% rename from doc/design/ops/src/dist-graph.png rename to doc/design/refactor/src/dist-graph.png diff --git a/doc/design/refactor/src/distributed_architecture.graffle b/doc/design/refactor/src/distributed_architecture.graffle new file mode 100644 index 0000000000000000000000000000000000000000..f8496e57326c38de7468eb452a7713291d57653c Binary files /dev/null and b/doc/design/refactor/src/distributed_architecture.graffle differ diff --git a/doc/design/refactor/src/distributed_architecture.png b/doc/design/refactor/src/distributed_architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..410c4510c6aab301dec95e6427fe80ac24e105fe Binary files /dev/null and b/doc/design/refactor/src/distributed_architecture.png differ diff --git a/doc/design/ops/src/local-graph.graffle b/doc/design/refactor/src/local-graph.graffle similarity index 100% rename from doc/design/ops/src/local-graph.graffle rename to doc/design/refactor/src/local-graph.graffle diff --git a/doc/design/ops/src/local-graph.png b/doc/design/refactor/src/local-graph.png similarity index 100% rename from doc/design/ops/src/local-graph.png rename to doc/design/refactor/src/local-graph.png diff --git a/doc/design/refactor/src/local_architecture.graffle b/doc/design/refactor/src/local_architecture.graffle new file mode 100644 index 0000000000000000000000000000000000000000..cc7783c45381f25ded0b898649322c81418ad317 Binary files /dev/null and b/doc/design/refactor/src/local_architecture.graffle differ diff --git a/doc/design/refactor/src/local_architecture.png b/doc/design/refactor/src/local_architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..4b999538b7825c805292ee28b5e3256d5543bd09 Binary files /dev/null and b/doc/design/refactor/src/local_architecture.png differ diff --git a/doc/design/refactor/src/paddle-compile.graffle b/doc/design/refactor/src/paddle-compile.graffle new file mode 100644 index 0000000000000000000000000000000000000000..a6348cc3dbcaca923c6e794681b2edb85cb9f8f6 Binary files /dev/null and b/doc/design/refactor/src/paddle-compile.graffle differ diff --git a/doc/design/refactor/src/paddle-compile.png b/doc/design/refactor/src/paddle-compile.png new file mode 100644 index 0000000000000000000000000000000000000000..e0f13d551ac41afaec627a57dea79356464bf0bf Binary files /dev/null and b/doc/design/refactor/src/paddle-compile.png differ diff --git a/doc/faq/index_cn.rst b/doc/faq/index_cn.rst index acbf4c87ae5242f6cfc593a7fddc649ee3a70d7c..b3ecfba791ead8349ded018a30059b03eacbdacd 100644 --- a/doc/faq/index_cn.rst +++ b/doc/faq/index_cn.rst @@ -390,4 +390,125 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数 * 如果发现最早的报错就是网络通信的问题,很有可能是非独占方式执行导致的端口冲突,可以联系OP,看当前MPI集群是否支持resource=full参数提交,如果支持增加此参数提交,并更换job 端口。 -* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。 \ No newline at end of file +* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。 + +19. PaddlePaddle如何输出多个层 +------------------------------ + +* 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下: + +.. code-block:: python + + inferer = paddle.inference.Inference(output_layer=[layer1, layer2], parameters=parameters) + +* 指定要输出的字段进行输出。以输出 :code:`value` 字段为例,代码如下: + +.. code-block:: python + + out = inferer.infer(input=data_batch, flatten_result=False, field=["value"]) + +这里设置 :code:`flatten_result=False`,得到的输出结果是元素个数等于输出字段数的 :code:`list`,该 :code:`list` 的每个元素是由所有输出层相应字段结果组成的 :code:`list`,每个字段结果的类型是 :code:`numpy.array`。:code:`flatten_result` 的默认值为 :code:`True`,该情况下,PaddlePaddle会分别对每个字段将所有输出层的结果按行进行拼接,如果各输出层该字段 :code:`numpy.array` 结果的相应维数不匹配,程序将不能正常运行。 + +20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用 +------------------------------------------------------------- + +* :code:`paddle.layer.memory` 用于获取特定layer上一时间步的输出,该layer是通过参数 :code:`name` 指定,即,:code:`paddle.layer.memory` 会关联参数 :code:`name` 取值相同的layer,并将该layer上一时间步的输出作为自身当前时间步的输出。 + +* PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。 + +21. dropout 使用 +----------------- + +* 在PaddlePaddle中使用dropout有两种方式 + + * 在相应layer的 :code:`layer_atter` 设置 :code:`drop_rate`,以 :code:`paddle.layer.fc` 为例,代码如下: + + .. code-block:: python + + fc = paddle.layer.fc(input=input, layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=0.5)) + + * 使用 :code:`paddle.layer.dropout`,以 :code:`paddle.layer.fc` 为例,代码如下: + + .. code-block:: python + + fc = paddle.layer.fc(input=input) + drop_fc = paddle.layer.dropout(input=fc, dropout_rate=0.5) + +* :code:`paddle.layer.dropout` 实际上使用了 :code:`paddle.layer.add_to`,并在该layer里采用第一种方式设置 :code:`drop_rate` 来使用dropout的。这种方式对内存消耗较大。 + +* PaddlePaddle在激活函数里实现dropout,而不是在layer里实现。 + +* :code:`paddle.layer.lstmemory`、:code:`paddle.layer.grumemory`、:code:`paddle.layer.recurrent` 不是通过一般的方式来实现对输出的激活,所以不能采用第一种方式在这几个layer里设置 :code:`drop_rate` 来使用dropout。若要对这几个layer使用dropout,可采用第二种方式,即使用 :code:`paddle.layer.dropout`。 + +22. 如何设置学习率退火(learning rate annealing) +------------------------------------------------ + +在相应的优化算法里设置learning_rate_schedule及相关参数,以使用Adam算法为例,代码如下: + +.. code-block:: python + + optimizer = paddle.optimizer.Adam( + learning_rate=1e-3, + learning_rate_decay_a=0.5, + learning_rate_decay_b=0.75, + learning_rate_schedule="poly",) + +PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedule及其对应学习率计算方式如下: + +* "constant" + + lr = learning_rate + +* "poly" + + lr = learning_rate * pow(1 + learning_rate_decay_a * num_samples_processed, -learning_rate_decay_b) + + 其中,num_samples_processed为已训练样本数,下同。 + +* "caffe_poly" + + lr = learning_rate * pow(1.0 - num_samples_processed / learning_rate_decay_a, learning_rate_decay_b) + +* "exp" + + lr = learning_rate * pow(learning_rate_decay_a, num_samples_processed / learning_rate_decay_b) + +* "discexp" + + lr = learning_rate * pow(learning_rate_decay_a, floor(num_samples_processed / learning_rate_decay_b)) + +* "linear" + + lr = max(learning_rate - learning_rate_decay_a * num_samples_processed, learning_rate_decay_b) + +* "manual" + + 这是一种按已训练样本数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下: + + .. code-block:: python + + optimizer = paddle.optimizer.Adam( + learning_rate=1e-3, + learning_rate_schedule="manual", + learning_rate_args="1000:1.0,2000:0.9,3000:0.8",) + + 在该示例中,当已训练样本数小于等于1000时,学习率为 :code:`1e-3 * 1.0`;当已训练样本数大于1000小于等于2000时,学习率为 :code:`1e-3 * 0.9`;当已训练样本数大于2000时,学习率为 :code:`1e-3 * 0.8`。 + +* "pass_manual" + + 这是一种按已训练pass数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下: + + .. code-block:: python + + optimizer = paddle.optimizer.Adam( + learning_rate=1e-3, + learning_rate_schedule="manual", + learning_rate_args="1:1.0,2:0.9,3:0.8",) + + 在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。 + +23. 出现 :code:`Duplicated layer name` 错误怎么办 +-------------------------------------------------- + +出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。 + diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index 84e33177740ca1652efc09c8081c2519b4366906..30b144d849bec367cd0197b6082889e011193a9a 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -20,7 +20,7 @@ Docker使用入门 docker pull paddlepaddle/paddle:0.10.0 - 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。 + 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用docker.paddlepaddle.org/paddle下载。 - *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。 实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。 diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md new file mode 100644 index 0000000000000000000000000000000000000000..b7aa501db9e5c7378398fad48503f82bff893b60 --- /dev/null +++ b/doc/howto/dev/new_op_en.md @@ -0,0 +1,235 @@ +# How to write a new operator + + - [Background](#Background) + - [Implementing C++ Types](#Implementing_C++_Types) + - [Defining ProtoMaker](#Defining_ProtoMaker) + - [Defining Operator](#Defining_Operator) + - [Registering Operator](#Registering_Operator) + - [Compilation](#Compilation) + - [Python Binding](#Python_Binding) + - [Unit Tests](#Unit_Tests) + +## Background + +Here are the base types needed. For details, please refer to the design docs. + +- `framework::OperatorBase`: Operator (Op)base class. +- `framework::OpKernel`: Base class for Op computation. +- `framework::OperatorWithKernel`: Inherited from OperatorBase, describing an operator with computation. +- `class OpProtoAndCheckerMaker`: Describes an Operator's input, output, attributes and description, mainly used to interface with Python API. + +An operator can be differentiated by whether in has kernel methods. An operator with kernel inherits from `OperatorWithKernel` while the ones without inherit from `OperatorBase`. This tutorial focuses on implementing operators with kernels. In short, an operator includes the following information: + + + Information | Where is it defined +-------------- | :---------------------- +OpProtoMake definition | `.cc`files, Backward Op does not need an OpProtoMake interface. +Op definition | `.cc` files +Kernel implementation | The kernel methods shared between CPU and GPU are defined in `.h` files. CPU-specific kernels live in `.cc` files, while GPU-specific kernels are implemented in `.cu`files. +Registering the Op | Ops are registered in `.cc` files; For Kernel registration, `.cc` files contain the CPU implementation, while `.cu` files contain the GPU implementation. + + +New Operator implementations are added to the list [paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators), with file names in the format `*_op.h` (if applicable), `*_op.cc`, `*_op.cu` (if applicable).** The system will use the naming scheme to automatically build operators and their corresponding Python extensions. ** + + +Let's take matrix multiplication operator, [MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc), as an example to introduce the writing of an Operator with Kernel. + + +## Implementing C++ Types + + +### 1. Defining Class ProtoMaker + +Matrix Multiplication can be written as $Out = X * Y$, meaning that the operation consists of two inputs and pne output. + +First, define `ProtoMaker` to describe the Operator's input, output, and additional comments: + +```cpp +class MulOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor), 2D tensor of size (M x K)"); + AddInput("Y", "(Tensor), 2D tensor of size (K x N)"); + AddOutput("Out", "(Tensor), 2D tensor of size (M x N)"); + AddComment(R"DOC( +Two Element Mul Operator. +The equation is: Out = X * Y +)DOC"); + } +}; +``` + +[`MulOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L43)is inherited from`framework::OpProtoAndCheckerMaker`, consisting of 2 variables in the constructor: + + - `framework::OpProto` stores Operator input and variable attribute, used for generating Python API interfaces. + - `framework::OpAttrChecker` is used to validate variable attributes. + +The constructor utilizes `AddInput`, `AddOutput`, and `AddComment`, so that the corresponding information will be added to `OpProto`. + +The code above adds two inputs `X` and `Y` to `MulOp`, an output `Out`, and their corresponding descriptions, in accordance to Paddle's [naming convention](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). + + +An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37) is implemented as follows: + +```cpp +template +class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of scale operator.").NotInGradient(); + AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); + AddComment(R"DOC(Scale operator +The equation is: Out = scale*X +)DOC"); + AddAttr("scale", "scale of scale operator.").SetDefault(1.0); + } +}; +``` + +There are two changes in this example: + +- `AddInput("X","...").NotInGradient()` expresses that input `X` is not involved in `ScaleOp`'s corresponding computation. If an input to an operator is not participating in back-propagation, please explicitly set `.NotInGradient()`. + +- `AddAttr("scale", "...").SetDefault(1.0);` adds `scale`constant as an attribute, and sets the default value to 1.0. + + +### 2. Defining Operator + +The following code defines the interface for MulOp: + +```cpp +class MulOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto dim0 = ctx.Input("X")->dims(); + auto dim1 = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_EQ(dim0.size(), 2, + "input X(%s) should be a tensor with 2 dims, a matrix", + ctx.op_.Input("X")); + PADDLE_ENFORCE_EQ(dim1.size(), 2, + "input Y(%s) should be a tensor with 2 dims, a matrix", + ctx.op_.Input("Y")); + PADDLE_ENFORCE_EQ( + dim0[1], dim1[0], + "First matrix's width must be equal with second matrix's height."); + ctx.Output("Out")->Resize({dim0[0], dim1[1]}); + } +}; +``` + +[`MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L22) is inherited from `OperatorWithKernel`. Its `public` member + +```cpp +using framework::OperatorWithKernel::OperatorWithKernel; +``` + +expresses an operator constructor using base class `OperatorWithKernel`, alternatively written as + +```cpp +MulOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} +``` + +`InferShape` interface needs to be re-written.`InferShape` is a constant method and cannot modify Op's member variables, its constant member `const framework::InferShapeContext &ctx` can be used to extract input, output, and attributes. It functions to + + - 1). validate and error out early: it checks input data dimensions and types. + - 2). configures the tensor shape in the output. + +Usually `OpProtoMaker` and `Op`'s type definitions are written in `.cc` files, which also include the registration methods introduced later. + +### 3. Defining OpKernel + +`MulKernel` inherits `framework::OpKernel`, which includes the following templates: + +- `typename Place` denotes device type. When different devices, namely the CPU and the GPU, share the same kernel, this template needs to be added. If they don't share kernels, this must not be added. An example of a non-sharing kernel is [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43). + +- `typename T` denotes data type, such as `float` or `double`. + +`MulKernel` types need to rewrite the interface for `Compute`. +- `Compute` takes one input variable `const framework::ExecutionContext& context`. +- Compared with `InferShapeContext`, `ExecutionContext` includes device types, and can similarly extract input, output, and attribute variables. +- `Compute` implements the computation logics of an `OpKernel`. + +`MulKernel`'s implementation of `Compute` is as follows: + + ```cpp + template + class MulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } + }; + ``` + +Note that **different devices (CPU, GPU)share an Op definition; whether or not they share the same `OpKernel` depends on whether `Compute` calls functions that support both devices.** + +`MulOp`'s CPU and GPU share the same `Kernel`. A non-sharing `OpKernel` example can be seen in [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43). + +To ease the writing of `OpKernel` compute, and for reusing code cross-device, `Eigen unsupported Tensor` module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md). + + +This concludes the forward implementation of an operator. Next its operation and kernel need to be registered in a `.cc` file. + +The definition of its corresponding backward operator, if applicable, is similar to that of an forward operator. **Note that a backward operator does not include a `ProtoMaker`**. + +### 4. Registering Operator + +- In `.cc` files, register forward and backward operator classes and the CPU kernel. + + ```cpp + namespace ops = paddle::operators; + REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); + REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); + REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); + ``` + + In that code block, + + - `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`. + - `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient. + - `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`. + + +- Registering GPU Kernel in `.cu` files + - Note that if GPU Kernel is implemented using the `Eigen unsupported` module, then on top of `.cu`, a macro definition `#define EIGEN_USE_GPU` is needed, such as + + ```cpp + // if use Eigen unsupported module before include head files + #define EIGEN_USE_GPU + + namespace ops = paddle::operators; + REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); + REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); + ``` + +### 5. Compilation + +Run the following commands to compile. + +``` +make mul_op +``` + +## Python Binding + +The system will automatically bind to Python and link it to a generated library. + +## Unit Tests + +Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index fda89252e35c382468877e8cab148e5f91d77ac2..510dc28c57f642786e7c64d86961c76ac80014a8 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -28,47 +28,6 @@ ProgramDesc& GetProgramDesc() { return *g_program_desc; } -template <> -AttrType AttrTypeID() { - return BOOLEAN; -} -template <> -AttrType AttrTypeID() { - return INT; -} -template <> -AttrType AttrTypeID() { - return FLOAT; -} -template <> -AttrType AttrTypeID() { - return STRING; -} -template <> -AttrType AttrTypeID>() { - return BOOLEANS; -} -template <> -AttrType AttrTypeID>() { - return INTS; -} -template <> -AttrType AttrTypeID>() { - return FLOATS; -} -template <> -AttrType AttrTypeID>() { - return STRINGS; -} -template <> -AttrType AttrTypeID>>() { - return INT_PAIRS; -} -template <> -AttrType AttrTypeID() { - return BLOCK; -} - Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { @@ -111,14 +70,6 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } - case framework::AttrType::INT_PAIRS: { - std::vector> val(attr_desc.int_pairs_size()); - for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { - val[i].first = attr_desc.int_pairs(i).first(); - val[i].second = attr_desc.int_pairs(i).second(); - } - return val; - } case framework::AttrType::BLOCK: { return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); } diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 48b54b5422de8c45e15a1b7040b78373dce8fa3a..488fa38faf12ee51087643f79295f36bfd33ee22 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,10 +27,10 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, std::vector, std::vector, - std::vector, - std::vector>, BlockDesc*> +// The order should be as same as framework.proto +typedef boost::variant, + std::vector, std::vector, bool, + std::vector, BlockDesc*> Attribute; typedef std::unordered_map AttributeMap; @@ -38,7 +38,10 @@ typedef std::unordered_map AttributeMap; ProgramDesc& GetProgramDesc(); template -AttrType AttrTypeID(); +inline AttrType AttrTypeID() { + Attribute tmp = T(); + return static_cast(tmp.which() - 1); +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc); diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 6fcfe6de25737b66a2ea6c1a438636f072a513bb..951c7afbc14e2d9119169c1351d38ff0b67bdc5b 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -22,17 +22,11 @@ enum AttrType { INTS = 3; FLOATS = 4; STRINGS = 5; - INT_PAIRS = 6; - BOOLEAN = 7; - BOOLEANS = 8; - BLOCK = 9; + BOOLEAN = 6; + BOOLEANS = 7; + BLOCK = 8; } -message IntPair { - required int32 first = 1; - required int32 second = 2; -}; - // OpDesc describes an instance of a C++ framework::OperatorBase // derived class type. message OpDesc { @@ -46,7 +40,6 @@ message OpDesc { repeated int32 ints = 6; repeated float floats = 7; repeated string strings = 8; - repeated IntPair int_pairs = 9; optional bool b = 10; repeated bool bools = 11; optional int32 block_idx = 12; diff --git a/paddle/gserver/activations/MKLDNNActivation.h b/paddle/gserver/activations/MKLDNNActivation.h index 86ffe387366409d81a91740cc8cea886e618f7e2..40dd8c618aa2b70d410130e12efc54520218afea 100644 --- a/paddle/gserver/activations/MKLDNNActivation.h +++ b/paddle/gserver/activations/MKLDNNActivation.h @@ -100,6 +100,7 @@ public: if (cnt_ == act.value->getElementCnt()) { return; } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; cnt_ = act.value->getElementCnt(); stream_.reset(new MKLDNNStream()); auto eng = CPUEngine::Instance().getEngine(); @@ -110,7 +111,6 @@ public: float alpha = getAlpha(); float beta = getBeta(); - /// forward pipelineFwd_.clear(); val_ = std::dynamic_pointer_cast(act.value); if (val_ == nullptr) { @@ -152,6 +152,7 @@ public: if (!needResetBwd_) { return; } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; needResetBwd_ = false; mkldnn::algorithm algo = getAlgo(this->getName()); float alpha = getBwdAlpha(); diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 88b047c89bd40aba1afc456c22a2870c62989c1c..9a0abd291ae8fae43b0e95c7371f3ce35d1261ec 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap, // create biases if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_, 0)); } return true; } @@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue( // create buffer and reorder if input value do not match cpuInVal_ = nullptr; cvtInVal_ = nullptr; - if (inputIsOnlyMKLDNN()) { - MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); - CHECK(dnnIn) << "Input should be MKLDNNMatrix"; - if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) { - CHECK_EQ(dnnIn->getFormat(), format::nc); + + MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); + CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr); + if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { + in = dnnIn; + return; + } + if (dnnIn) { + if (dnnIn->getFormat() == format::nc) { CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format"; // create a new one with nchw format and same data memory::dims inDims = memory::dims{bs_, ic_, 1, 1}; dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); - CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()); } - in = dnnIn; + if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { + in = dnnIn; + return; + } + cpuInVal_ = dnnIn; + in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); + cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in); + CHECK(cvtInVal_) << "should not be emptry"; } else { - const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; - cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_); + cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) { // create new mkldnn matrix in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); @@ -535,7 +544,7 @@ void MKLDNNConvLayer::resetWgtValBwdData( } else { wgtValBwdData_ = wgtVal_; } - VLOG(MKLDNN_FMTS) << "weight value format for backward data" + VLOG(MKLDNN_FMTS) << "weight value format for backward data: " << wgtValBwdData_->getFormat(); } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index afd092666bf8b8a3389b36aa1f0edb256a9968e6..8cbfbd0d2b9f2149f7c959aec5c4ae1de952f903 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, // create biases if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_, 0)); } return true; } @@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) { void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias) { + format wgtFmt = format::oihw; + if (inVal_->getFormat() == format::nChw8c) { + wgtFmt = format::oIhw8i; + } else if (inVal_->getFormat() == format::nChw16c) { + wgtFmt = format::oIhw16i; + } wgt = MKLDNNMatrix::create( - weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_); + weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_); wgt->downSpatial(); + VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); bias = (biases_ && biases_->getW()) ? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_) diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index d8555a833187ddf64b096135e920e5be2b3a8c2f..c09fd89462ef4fdaeaae3e122f96b0cc6ce373ea 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -115,6 +115,7 @@ public: copySeqInfoToOutputs(); size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt(); if (inputElemenCnt_ != elemenCnt) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; // reset when input total sizes changed, not only the batchsize inputElemenCnt_ = elemenCnt; reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); @@ -142,6 +143,7 @@ public: void backward(const UpdateCallback& callback) override { if (needResetBwd_) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); needResetBwd_ = false; } diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 0a6a0fd15c73330902552f7a9aa6339de24c1a18..75e8a989036f0b818687e1fec3e600bb90e86b22 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { return; } - AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>( - num_samples, infer_width, inference_data, label_data, accuracy_data); + AccuracyCudaKernel<<< + 1, PADDLE_CUDA_NUM_THREADS, 0, + reinterpret_cast( + ctx.device_context()) + .stream()>>>(num_samples, infer_width, inference_data, label_data, + accuracy_data); } }; diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h index 2f40c059033ec649b29f6ecdee4fcedd128a63a6..ac3aeaf41e206c1deb74c7022c36f02c4777a84b 100644 --- a/paddle/operators/crop_op.h +++ b/paddle/operators/crop_op.h @@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel { auto out_stride = framework::stride(out->dims()); auto offsets = context.Attr>("offsets"); PADDLE_ENFORCE_EQ( - x->dims().size(), offsets.size(), + x->dims().size(), static_cast(offsets.size()), "Offsets size should be equal to dimension size of input tensor."); int64_t offset = 0; - for (int i = 0; i < offsets.size(); ++i) { + for (size_t i = 0; i < offsets.size(); ++i) { offset += (x_stride[i] * offsets[i]); } StridedMemcpy(context.device_context(), x_data + offset, x_stride, @@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { d_x->mutable_data(context.GetPlace()); auto offsets = context.Attr>("offsets"); Eigen::array, D> paddings; - for (int i = 0; i < D; ++i) { + for (size_t i = 0; i < D; ++i) { paddings[i].first = offsets[i]; paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; } diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index b11dc1472d153dd188a0b3553d6950774216a3fd..2e16201e74c153888594ebe6679fb0036734dad4 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null."); + "Input(Label) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) should be not null."); auto x = ctx.Input("X"); auto label = ctx.Input("Label"); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank must be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], - "The 1st dimension of Input(X) and Input(Label) must " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); - if (ctx.Attr("soft_label")) { + if (ctx.Attr("softLabel")) { PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], - "If Attr(soft_label) == true, The 2nd dimension of " - "Input(X) and Input(Label) must be equal."); + "If Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); } else { PADDLE_ENFORCE_EQ(label->dims()[1], 1, - "If Attr(soft_label) == false, The 2nd dimension of " - "Input(Label) must be 1."); + "If Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); } ctx.Output("Y")->Resize({x->dims()[0], 1}); @@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) must not be null."); + "Input(Label) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), - "Input(Y@GRAD) must not be null."); + "Input(Y@GRAD) shoudl be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); auto x = ctx.Input("X"); auto label = ctx.Input("Label"); auto dy = ctx.Input(framework::GradVarName("Y")); - PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy->dims().size(), 2, + "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 2, - "Input(Label)'s rank must be 2."); + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], - "The 1st dimension of Input(X) and Input(Label) must " + "The 1st dimension of Input(X) and Input(Label) should " "be equal."); PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], - "The 1st dimension of Input(X) and Input(Y@Grad) must " + "The 1st dimension of Input(X) and Input(Y@Grad) should " "be equal."); PADDLE_ENFORCE_EQ(dy->dims()[1], 1, - "The 2nd dimension of Input(Y@Grad) must be 1."); - if (ctx.Attr("soft_label")) { + "The 2nd dimension of Input(Y@Grad) should be 1."); + if (ctx.Attr("softLabel")) { PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], - "If Attr(soft_label) == true, The 2nd dimension of " - "Input(X) and Input(Label) must be equal."); + "When Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); } else { PADDLE_ENFORCE_EQ(label->dims()[1], 1, - "If Attr(soft_label) == false, The 2nd dimension of " - "Input(Label) must be 1."); + "When Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); } auto dx = ctx.Output(framework::GradVarName("X")); @@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { CrossEntropyOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of CrossEntropyOp"); - AddInput("Label", "The second input of CrossEntropyOp"); - AddOutput("Y", "The output of CrossEntropyOp"); - AddAttr("soft_label", "Is soft label. Default zero.") + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape N x D, " + "where N is the batch size and D is the number of classes. " + "This input is a probability computed by the previous operator, " + "which is almost always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor, default Tensor), the ground truth which is " + "a 2-D tensor. " + "When softLabel is set to false, `Label` is a Tensor with shape " + "[N x 1]. " + "When softLabel is set to true, `Label` is a Tensor " + "with shape [N x K]."); + AddOutput("Y", + "(Tensor, default Tensor), a 2-D tensor " + "with shape [N x 1]. The cross entropy loss."); + AddAttr( + "softLabel", + "(bool, default false), a flag to indicate whether to interpretate " + "the given labels as soft labels.") .SetDefault(false); - AddComment(R"DOC( CrossEntropy Operator. It supports both standard cross-entropy and soft-label cross-entropy loss computation. 1) One-hot cross-entropy: - soft_label = False, Label[i, 0] indicates the class index for sample i: + softLabel = false, Label[i, 0] indicates the class index for sample i: Y[i] = -log(X[i, Label[i]]) 2) Soft-label cross-entropy: - soft_label = True, Label[i, j] indicates the soft label of class j + softLabel = true, Label[i, j] indicates the soft label of class j for sample i: Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 2989e55075f5247cb6185afc96e781bb622204fb..18e44d77c9f62b296dc57952e546f844670c7d57 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -32,22 +32,45 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, } } +template +__device__ __forceinline__ T sum_single_warp(T val) { + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + return val; +} + template __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, - const int N, const int D) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - T sum = static_cast(0); - for (int j = 0; j < D; j++) { - sum += label[i * D + j] * TolerableValue()(log(X[i * D + j])); - } - Y[i] = -sum; + const int class_num) { + int tid = threadIdx.x; + extern __shared__ T d_sum[]; + d_sum[tid] = 0; + + int cur_idx = tid; + int next_idx = blockIdx.x * class_num + tid; + while (cur_idx < class_num) { + d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; + next_idx += blockDim.x; + cur_idx += blockDim.x; + } + __syncthreads(); + + for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { + if (tid < stride) d_sum[tid] += d_sum[tid + stride]; + __syncthreads(); } + + T val = d_sum[tid]; + val = sum_single_warp(val); + if (tid == 0) Y[blockIdx.x] = -val; } -// TODO(qingqing): make zero setting an common function. +// TODO(qingqing): make zero setting a common function. template -__global__ void zero(T* X, const int N) { +__global__ void Zero(T* X, const int N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { X[i] = 0.0; @@ -71,13 +94,10 @@ template __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const T* label, const int N, const int D) { - // TOOD(qingqing): optimize for this kernel - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - for (int j = 0; j < D; ++j) { - int idx = i * D + j; - dX[idx] = -label[idx] * dY[i] / X[idx]; - } + int ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < N * D) { + int row_ids = ids / D; + dX[ids] = -label[ids] * dY[row_ids] / X[ids]; } } @@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); + "This kernel only runs on GPU device."); - auto x = ctx.Input("X"); - auto y = ctx.Output("Y"); - auto label = ctx.Input("Label"); + const Tensor* x = ctx.Input("X"); + const Tensor* label = ctx.Input("Label"); + Tensor* y = ctx.Output("Y"); - auto* x_data = x->data(); - y->mutable_data(ctx.GetPlace()); - auto* y_data = y->data(); + const T* x_data = x->data(); + T* y_data = y->mutable_data(ctx.GetPlace()); - int n = x->dims()[0]; - int d = x->dims()[1]; - int block = 512; - int grid = (n + block - 1) / block; - // TODO(qingqing) launch kernel on specified stream - // base on ExecutionContext. - if (ctx.Attr("soft_label")) { + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; + + if (ctx.Attr("softLabel")) { auto* label_data = ctx.Input("Label")->data(); - SoftCrossEntropyKernel<<>>(y_data, x_data, label_data, n, - d); + int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); + + SoftCrossEntropyKernel< + T><<( + ctx.device_context()) + .stream()>>>(y_data, x_data, label_data, class_num); } else { auto* label_data = ctx.Input("Label")->data(); - CrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); + int block = 512; + int grid = (batch_size + block - 1) / block; + CrossEntropyKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(y_data, x_data, label_data, + batch_size, class_num); } } }; @@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); + "This kernel only runs on GPU device."); + + const Tensor* x = ctx.Input("X"); + const Tensor* label = ctx.Input("Label"); + Tensor* dx = ctx.Output(framework::GradVarName("X")); - auto x = ctx.Input("X"); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("Label"); + const T* dy_data = + ctx.Input(framework::GradVarName("Y"))->data(); + T* dx_data = dx->mutable_data(ctx.GetPlace()); + const T* x_data = x->data(); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); - auto* x_data = x->data(); + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; - int n = x->dims()[0]; - int d = x->dims()[1]; int block = 512; - int grid = (n * d + block - 1) / block; - zero<<>>(dx_data, n * d); - grid = (n + block - 1) / block; - // TODO(qingqing): launch kernel on specified stream - // base on ExecutionContext. - if (ctx.Attr("soft_label")) { + int grid = (batch_size * class_num + block - 1) / block; + + if (ctx.Attr("softLabel")) { auto* label_data = label->data(); - SoftCrossEntropyGradientKernel<<>>( - dx_data, dy_data, x_data, label_data, n, d); + SoftCrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(dx_data, dy_data, x_data, label_data, + batch_size, class_num); } else { + Zero<<( + ctx.device_context()) + .stream()>>>(dx_data, batch_size * class_num); + auto* label_data = label->data(); - CrossEntropyGradientKernel<<>>(dx_data, dy_data, x_data, - label_data, n, d); + grid = (batch_size + block - 1) / block; + CrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(dx_data, dy_data, x_data, label_data, + batch_size, class_num); } } }; diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 942a532f642e74a112a084eb60d1daaa396ca578..255b2e9f5ea7566cca7fd3914e38da804b7c7006 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/platform/hostdevice.h" @@ -20,6 +21,9 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; template struct TolerableValue { @@ -38,32 +42,27 @@ class CrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto x = ctx.Input("X"); - auto y = ctx.Output("Y"); - - auto* x_data = x->data(); - y->mutable_data(ctx.GetPlace()); - auto* y_data = y->data(); - - int batch_size = x->dims()[0]; - int class_num = x->dims()[1]; - - if (ctx.Attr("soft_label")) { - auto* label_data = ctx.Input("Label")->data(); - int index = 0; - for (int i = 0; i < batch_size; ++i) { - T sum = static_cast(0); - for (int j = 0; j < class_num; ++j) { - sum += - label_data[index] * TolerableValue()(std::log(x_data[index])); - y_data[i] = -sum; - index++; - } - } + "This kernel only runs on CPU."); + const Tensor* x = ctx.Input("X"); + const Tensor* labels = ctx.Input("Label"); + Tensor* y = ctx.Output("Y"); + T* y_data = y->mutable_data(ctx.GetPlace()); + + const int batch_size = x->dims()[0]; + if (ctx.Attr("softLabel")) { + auto prob = EigenMatrix::From(*x); + auto lbl_mat = EigenMatrix::From(*labels); + auto loss = EigenMatrix::From(*y); + + loss.device(ctx.GetEigenDevice()) = + -((lbl_mat * prob.log().unaryExpr(TolerableValue())) + .sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(batch_size, 1))); } else { - auto* label_data = ctx.Input("Label")->data(); + const int class_num = x->dims()[1]; + const T* x_data = x->data(); + + const int* label_data = labels->data(); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; y_data[i] = -TolerableValue()(std::log(x_data[index])); @@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); + "This kernel only runs on CPU."); + const Tensor* x = ctx.Input("X"); + const Tensor* dy = ctx.Input(framework::GradVarName("Y")); + const Tensor* label = ctx.Input("Label"); + Tensor* dx = ctx.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(ctx.GetPlace()); - auto x = ctx.Input("X"); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dy = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("Label"); - - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->data(); - auto* x_data = x->data(); - - int batch_size = x->dims()[0]; int class_num = x->dims()[1]; - - // TODO(qingqing): make zero setting an common function. - if (ctx.Attr("soft_label")) { - auto* label_data = ctx.Input("Label")->data(); - int index = 0; - for (int i = 0; i < batch_size; ++i) { - for (int j = 0; j < class_num; ++j) { - dx_data[index] = -label_data[index] * dy_data[i] / x_data[index]; - index++; - } - } + if (ctx.Attr("softLabel")) { + auto x_mat = EigenMatrix::From(*x); + auto dy_mat = EigenMatrix::From(*dy); + auto lbl_mat = EigenMatrix::From(*label); + auto dx_mat = EigenMatrix::From(*dx); + + dx_mat.device(ctx.GetEigenDevice()) = + -(lbl_mat * dy_mat.broadcast(Eigen::DSizes(1, class_num)) / + x_mat); } else { - auto* label_data = label->data(); + int batch_size = x->dims()[0]; + const T* dy_data = dy->data(); + const T* x_data = x->data(); + const int* label_data = label->data(); + + // TODO(qingqing): make zero setting a common function. memset(dx_data, 0, sizeof(T) * batch_size * class_num); + for (int i = 0; i < batch_size; ++i) { PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); int index = i * class_num + label_data[i]; diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 708344046760691aa2da562eb1ee3d8b130c5f18..62f63b4f3c876e084e2468001e8bcb9310d16a82 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTable<<>>(output, table, ids, N, K, D); + LookupTable<<< + grids, threads, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(output, table, ids, N, K, D); } }; @@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTableGrad<<>>(d_table, d_output, ids, N, - K, D); + LookupTableGrad<<< + grids, threads, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(d_table, d_output, ids, N, K, D); } }; diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3600f199770c4b8c9a6561b4c270a91bc8b20c0b --- /dev/null +++ b/paddle/operators/lstm_unit_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lstm_unit_op.h" + +namespace paddle { +namespace operators { + +class LstmUnitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), + "Input(C_prev) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), + "Output(C) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), + "Output(H) of LSTM should not be null."); + + auto *x = ctx.Input("X"); + auto *c_prev = ctx.Input("C_prev"); + + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, + "Dimension of FC should equal to prev state * 4"); + + int b_size = c_prev->dims()[0]; // batch size + int s_dim = c_prev->dims()[1]; // state dim + ctx.Output("C")->Resize({b_size, s_dim}); + ctx.Output("H")->Resize({b_size, s_dim}); + } +}; + +template +class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LstmUnitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "FC input before the non-linear activation."); + AddInput( + "C_prev", + "The cell state tensor of last time-step in the Lstm Unit operator."); + AddOutput("C", "The cell tensor of Lstm Unit operator."); + AddOutput("H", "The hidden state tensor of Lstm Unit operator."); + + AddComment(R"DOC(Lstm-Unit Operator + +Equation: + i, f, o, j = split(X) + C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) + H = C * sigm(o) + +)DOC"); + AddAttr("forget_bias", "The forget bias of Lstm Unit.") + .SetDefault(0.0); + } +}; + +class LstmUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), + "Input(C@GRAD) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), + "Input(H@GRAD) should not be null"); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("X")->dims()); + ctx.Output(framework::GradVarName("C_prev")) + ->Resize(ctx.Input("C_prev")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, + lstm_unit_grad, ops::LstmUnitGradOp); +REGISTER_OP_CPU_KERNEL(lstm_unit, + ops::LstmUnitKernel); +REGISTER_OP_CPU_KERNEL( + lstm_unit_grad, ops::LstmUnitGradKernel); diff --git a/paddle/operators/lstm_unit_op.cu b/paddle/operators/lstm_unit_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e5e4978994c281416a65af5f8ffdec688768d63 --- /dev/null +++ b/paddle/operators/lstm_unit_op.cu @@ -0,0 +1,173 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/cross_entropy_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ Dtype cuda_sigmoid(const Dtype x) { + return Dtype(1) / (Dtype(1) + exp(-x)); +} + +template +__device__ Dtype cuda_tanh(const Dtype x) { + return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x)); +} + +template +__global__ void LSTMUnitKernel(const int nthreads, const int dim, + const T* C_prev, const T* X, T* C, T* H, + const T forget_bias) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + + const T* X_offset = X + 4 * dim * n; + const T i = cuda_sigmoid(X_offset[d]); + const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias); + const T o = cuda_sigmoid(X_offset[2 * dim + d]); + const T g = cuda_tanh(X_offset[3 * dim + d]); + const T c_prev = C_prev[index]; + const T c = f * c_prev + i * g; + C[index] = c; + const T tanh_c = cuda_tanh(c); + H[index] = o * tanh_c; + } +} + +template +__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, + const T* C_prev, const T* X, const T* C, + const T* H, const T* C_diff, + const T* H_diff, T* C_prev_diff, + T* X_diff, const T forget_bias) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const T* X_offset = X + 4 * dim * n; + T* c_prev_diff = C_prev_diff + index; + T* X_diff_offset = X_diff + 4 * dim * n; + T* i_diff = X_diff_offset + d; + T* f_diff = X_diff_offset + 1 * dim + d; + T* o_diff = X_diff_offset + 2 * dim + d; + T* g_diff = X_diff_offset + 3 * dim + d; + + const T i = cuda_sigmoid(X_offset[d]); + const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias); + const T o = cuda_sigmoid(X_offset[2 * dim + d]); + const T g = cuda_tanh(X_offset[3 * dim + d]); + const T c_prev = C_prev[index]; + const T c = C[index]; + const T tanh_c = cuda_tanh(c); + const T c_term_diff = + C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[index] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } +} + +template +class LstmUnitOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + int block = 512; + int n = b_size * D; + int grid = (n + block - 1) / block; + + LSTMUnitKernel<<>>(n, D, C_prev, X, C, H, forget_bias); + } +}; + +template +class LstmUnitGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int block = 512; + int n = N * D; + int grid = (n + block - 1) / block; + + LSTMUnitGradientKernel<<>>(n, D, C_prev, X, C, H, C_diff, + H_diff, C_prev_diff, X_diff, + forget_bias); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel); diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h new file mode 100644 index 0000000000000000000000000000000000000000..683034fe15df8cabfdff5e856adb5c0467055064 --- /dev/null +++ b/paddle/operators/lstm_unit_op.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; + +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +class LstmUnitKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + for (int n = 0; n < b_size; ++n) { + for (int d = 0; d < D; ++d) { + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = f * c_prev + i * g; + C[d] = c; + const T tanh_c = tanh(c); + H[d] = o * tanh_c; + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + } + } +}; + +template +class LstmUnitGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + for (int n = 0; n < N; ++n) { + for (int d = 0; d < D; ++d) { + T* c_prev_diff = C_prev_diff + d; + T* i_diff = X_diff + d; + T* f_diff = X_diff + 1 * D + d; + T* o_diff = X_diff + 2 * D + d; + T* g_diff = X_diff + 3 * D + d; + + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = C[d]; + const T tanh_c = tanh(c); + const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[d] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + C_diff += D; + H_diff += D; + X_diff += 4 * D; + C_prev_diff += D; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index def4b01da098fc960ce7c0e497732fbcc2579945..ba653afa2cb175ae2e5e21088b6dc7ba76a6018f 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -48,6 +48,32 @@ void gemm(const platform::DeviceContext& context, beta, C, ldc); } +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const float alpha, const float* A, + const int lda, const float* B, + const int ldb, const float beta, float* C, + const int ldc) { + cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const double alpha, const double* A, + const int lda, const double* B, + const int ldb, const double beta, + double* C, const int ldc) { + cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + template <> void matmul( const platform::DeviceContext& context, const framework::Tensor& matrix_a, diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 71563b77b4b262c3f1e17ae7c4381da56ba780a3..649f1f352c2a4a5ebaa0cb00ffb2e4de8aa4961a 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -63,6 +63,42 @@ void gemm(const platform::DeviceContext& context, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const float alpha, const float* A, + const int lda, const float* B, + const int ldb, const float beta, float* C, + const int ldc) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasSgemm( + reinterpret_cast(context) + .cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); +} + +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const double alpha, const double* A, + const int lda, const double* B, + const int ldb, const double beta, + double* C, const int ldc) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasDgemm( + reinterpret_cast(context) + .cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); +} + template <> void matmul( const platform::DeviceContext& context, const framework::Tensor& matrix_a, diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index d8518e77fa7b4abdbcf08b7983013c24806e14ca..43306fca73387b7b212f556a2b187df113a1b327 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, const T beta, T* C); +// gemm wrapper with stride args for matrix uncontinuous in memory +template +void gemm(const platform::DeviceContext& context, const bool transA, + const bool transB, const int M, const int N, const int K, + const T alpha, const T* A, const int lda, const T* B, const int ldb, + const T beta, T* C, const int ldc); + // matrix multiply with continuous memory template void matmul(const platform::DeviceContext& context, diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 7e339457f7f08ff16162f399064a4b4dca594d7f..f272f7e5135e7092618b8c94ee55faf1cfd8e8a5 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -72,4 +72,174 @@ TEST(math_function, trans_mul_notrans) { EXPECT_EQ(out_ptr[8], 29); delete gpu_place; } + +TEST(math_function, gemm_notrans_cublas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor input3_gpu; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({3, 4}, *cpu_place); + float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input2, *gpu_place); + input3_gpu.CopyFrom(input3, *gpu_place); + float* a = input1_gpu.data(); + float* b = input2_gpu.data(); + float* c = input3_gpu.mutable_data(*gpu_place); + + paddle::operators::math::gemm( + context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); + + input3.CopyFrom(input3_gpu, *cpu_place); + + // numpy code: + // a = np.arange(6).reshape(2, 3) + // b = np.arange(12).reshape(3, 4)[:, 1:] + // c = np.arange(8).reshape(2, 4)[:, 1:] + // out = np.arange(8).reshape(2, 4) + // out[:, 1:] = np.dot(a, b) + c + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); + delete gpu_place; +} + +TEST(math_function, gemm_trans_cublas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor input3_gpu; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({4, 3}, *cpu_place); + float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input2, *gpu_place); + input3_gpu.CopyFrom(input3, *gpu_place); + float* a = input1_gpu.data(); + float* b = input2_gpu.data(); + float* c = input3_gpu.mutable_data(*gpu_place); + + paddle::operators::math::gemm( + context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); + + input3.CopyFrom(input3_gpu, *cpu_place); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); + delete gpu_place; +} #endif + +TEST(math_function, gemm_notrans_cblas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({3, 4}, *cpu_place); + float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::gemm( + context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1, + input3_ptr + 1, 4); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); +} + +TEST(math_function, gemm_trans_clbas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({4, 3}, *cpu_place); + float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::gemm( + context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1, + input3_ptr + 1, 4); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); +} diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e77b86b5698a263b850a973cd1b8644a0aa2201 --- /dev/null +++ b/paddle/operators/multiplex_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/multiplex_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class MultiplexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) shouldn't be null."); + auto ins = ctx.MultiInput("X"); + auto *out = ctx.Output("Out"); + auto num_ins = ins.size(); + PADDLE_ENFORCE(num_ins > 2, + "multiplex operator should have more than 2 inputs."); + PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1, + "The first input must be a index vector."); + auto in_dim = ins[1]->dims(); + + for (size_t i = 2; i < num_ins; i++) { + auto dim = ins[i]->dims(); + PADDLE_ENFORCE( + in_dim == dim, + "All the input tensors except the first one must have the same size"); + } + out->Resize(in_dim); + } +}; + +class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MultiplexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensors of multiplex operator.").AsDuplicable(); + AddOutput("Out", "The output tensor of multiplex operator."); + AddComment(R"DOC(Multiplex operator + +Multiplex multiple tensors according to the index provided by the first +input tensor. + +ins[0]: the index tensor. +ins[1:N]: the candidate output tensors. +For each index i from 0 to batchSize - 1, the output is the i-th row of the +the (index[i] + 1)-th tensor. + +For i-th row of the output tensor: + +y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1) + +where y is the output tensor. `x_{k}` is the k-th input tensor +and `k = x{0}[i] + 1`. + +)DOC"); + } +}; + +class MultiplexGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), + "Output(X@Grad) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + auto ins = ctx.MultiInput("X"); + // don't compute gradient for index (ins[0]) + for (size_t i = 1; i < ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->Resize(ins[i]->dims()); + } + } + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad, + ops::MultiplexGradOp); +REGISTER_OP_CPU_KERNEL( + multiplex, ops::MultiplexCPUKernel); +REGISTER_OP_CPU_KERNEL( + multiplex_grad, + ops::MultiplexGradCPUKernel); diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..4736f15bd594178168e3bcf799142d0fc18bff13 --- /dev/null +++ b/paddle/operators/multiplex_op.cu @@ -0,0 +1,95 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/multiplex_op.h" + +namespace paddle { +namespace operators { + +template +class MultiplexGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + // copy index to cpu + framework::Tensor index_t_cpu; + index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); + auto* index = index_t_cpu.data(); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + PADDLE_ENFORCE_LT(k, ins.size(), + "index exceeds the number of candidate tensors."); + memory::Copy(place, out->data() + i * cols, place, + ins[k]->data() + i * cols, cols * sizeof(T), stream); + } + } +}; + +template +class MultiplexGradGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto d_ins = + ctx.MultiOutput(framework::GradVarName("X")); + for (size_t i = 1; i < d_ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*d_ins[i]); + t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); + } + } + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + // copy index to cpu + framework::Tensor index_t_cpu; + index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); + auto* index = index_t_cpu.data(); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + if (d_ins[k]) { + memory::Copy(place, d_ins[k]->data() + i * cols, place, + d_out->data() + i * cols, cols * sizeof(T), stream); + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + multiplex, ops::MultiplexGPUKernel); +REGISTER_OP_GPU_KERNEL( + multiplex_grad, + ops::MultiplexGradGPUKernel); diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h new file mode 100644 index 0000000000000000000000000000000000000000..98466426bd90bc30a22ecf74e6739e2d4ad1d21d --- /dev/null +++ b/paddle/operators/multiplex_op.h @@ -0,0 +1,78 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +template +class MultiplexCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + auto* index = ins[0]->data(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + PADDLE_ENFORCE_LT(static_cast(k), ins.size(), + "index exceeds the number of candidate tensors."); + memory::Copy(place, out->data() + i * cols, place, + ins[k]->data() + i * cols, cols * sizeof(T)); + } + } +}; + +template +class MultiplexGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto d_ins = + ctx.MultiOutput(framework::GradVarName("X")); + for (size_t i = 1; i < d_ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*d_ins[i]); + t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); + } + } + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + auto* index = ins[0]->data(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + if (d_ins[k]) { + memory::Copy(place, d_ins[k]->data() + i * cols, place, + d_out->data() + i * cols, cols * sizeof(T)); + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sequence_avg_pool_op.cc b/paddle/operators/sequence_pool_op.cc similarity index 53% rename from paddle/operators/sequence_avg_pool_op.cc rename to paddle/operators/sequence_pool_op.cc index 9815b8f3a8d813959949bbfedc79f404721a8216..73f9cb879a2ef690909428b3b672b12717a6a02c 100644 --- a/paddle/operators/sequence_avg_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/sequence_avg_pool_op.h" +#include "paddle/operators/sequence_pool_op.h" namespace paddle { namespace operators { -class SequenceAvgPoolOp : public framework::OperatorWithKernel { +class SequencePoolOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("X"), "Input(X) of SequenceAvgPoolOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of SequencePoolOp should not be null."); PADDLE_ENFORCE_NOT_NULL( ctx.OutputVar("Out"), - "Output(Out) of SequenceAvgPoolOp should not be null."); + "Output(Out) of SequencePoolOp should not be null."); auto* x = ctx.Input("X"); auto dims = x->dims(); @@ -42,21 +42,45 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel { } }; -class SequenceAvgPoolOpMaker : public framework::OpProtoAndCheckerMaker { +class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { public: - SequenceAvgPoolOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) + SequencePoolOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input of SequenceAvgPoolOp."); - AddOutput("Out", "The output of SequenceAvgPoolOp."); + AddInput("X", + "A float LoDTensor, the variable-length input of SequencePoolOp"); + AddOutput( + "Out", + "A float LoDTensor, the variable-length output of SequencePoolOp."); + AddAttr( + "strategy", + "(int, default AVERAGE) the pooling strategy of SequencePoolOp.") + .SetDefault(AVERAGE) + .InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST}); AddComment(R"DOC( - SequenceAvgPoolOp averages features of all time-steps of each instance. - More detailed comments will be added later. + SequencePoolOp pools features of all time-steps of each instance. + + For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 time-steps: + + Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7]. + Besides, for the sake of simplicity, we assume M=1 and N=1, + and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. + + Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr. + And for different strategy, the value of Out is as follows: + + - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 + - SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1 + - SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), + 6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2) + - MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1) + - LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1) + - FIRST: [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1) )DOC"); } }; -class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { +class SequencePoolGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -84,12 +108,10 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sequence_avg_pool, ops::SequenceAvgPoolOp, - ops::SequenceAvgPoolOpMaker, sequence_avg_pool_grad, - ops::SequenceAvgPoolGradOp); +REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, + sequence_pool_grad, ops::SequencePoolGradOp); REGISTER_OP_CPU_KERNEL( - sequence_avg_pool, - ops::SequenceAvgPoolKernel); + sequence_pool, ops::SequencePoolKernel); REGISTER_OP_CPU_KERNEL( - sequence_avg_pool_grad, - ops::SequenceAvgPoolGradKernel); + sequence_pool_grad, + ops::SequencePoolGradKernel); diff --git a/paddle/operators/sequence_avg_pool_op.cu b/paddle/operators/sequence_pool_op.cu similarity index 74% rename from paddle/operators/sequence_avg_pool_op.cu rename to paddle/operators/sequence_pool_op.cu index bc9d1611fccd17c99b914b6ef59995288a9ebbd6..66850772d501f873cf754205c19e9d0c0090370a 100644 --- a/paddle/operators/sequence_avg_pool_op.cu +++ b/paddle/operators/sequence_pool_op.cu @@ -14,12 +14,11 @@ #define EIGEN_USE_GPU -#include "paddle/operators/sequence_avg_pool_op.h" +#include "paddle/operators/sequence_pool_op.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - sequence_avg_pool, - ops::SequenceAvgPoolKernel); + sequence_pool, ops::SequencePoolKernel); REGISTER_OP_GPU_KERNEL( - sequence_avg_pool_grad, - ops::SequenceAvgPoolGradKernel); + sequence_pool_grad, + ops::SequencePoolGradKernel); diff --git a/paddle/operators/sequence_avg_pool_op.h b/paddle/operators/sequence_pool_op.h similarity index 62% rename from paddle/operators/sequence_avg_pool_op.h rename to paddle/operators/sequence_pool_op.h index ebe0956344eb71d0fb2836f1b4a989ac546d9f78..231614b4c1cb0eb1901b1720e933aed5cbb25f77 100644 --- a/paddle/operators/sequence_avg_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -28,54 +28,85 @@ template using EigenMatrix = framework::EigenMatrix; +enum SeqPoolType { + AVERAGE = 0, + SUM = 1, + SQRT = 2, // square_root_n + MAX = 3, + LAST = 4, + FIRST = 5 +}; + template -class SequenceAvgPoolKernel : public framework::OpKernel { +class SequencePoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); + int strategy = context.Attr("strategy"); auto dims = in->dims(); - auto lod = in->lod(); + auto lod = in->lod()[0]; int64_t w = in->numel() / dims[0]; out->mutable_data(context.GetPlace()); auto place = context.GetEigenDevice(); - for (int i = 0; i < static_cast(lod[0].size()) - 1; ++i) { - Tensor in_t = in->Slice(static_cast(lod[0][i]), - static_cast(lod[0][i + 1])); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + Tensor in_t = + in->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); Tensor out_t = out->Slice(i, i + 1); - int64_t h = static_cast(lod[0][i + 1] - lod[0][i]); + int64_t h = static_cast(lod[i + 1] - lod[i]); auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w})); auto out_e = EigenVector::Flatten(out_t); - out_e.device(place) = in_e.mean(Eigen::array({{0}})); + + switch (strategy) { + case AVERAGE: + out_e.device(place) = in_e.mean(Eigen::array({{0}})); + break; + case SUM: + out_e.device(place) = in_e.sum(Eigen::array({{0}})); + break; + default: + PADDLE_THROW("unsupported pooling strategy"); + } } } }; template -class SequenceAvgPoolGradKernel : public framework::OpKernel { +class SequencePoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out_g = context.Input(framework::GradVarName("Out")); auto* in_g = context.Output(framework::GradVarName("X")); + int strategy = context.Attr("strategy"); auto dims = in->dims(); - auto lod = in->lod(); + auto lod = in->lod()[0]; int64_t w = in->numel() / dims[0]; in_g->mutable_data(context.GetPlace()); auto place = context.GetEigenDevice(); - for (int i = 0; i < static_cast(lod[0].size()) - 1; ++i) { - auto in_g_t = in_g->Slice(static_cast(lod[0][i]), - static_cast(lod[0][i + 1])); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + auto in_g_t = in_g->Slice(static_cast(lod[i]), + static_cast(lod[i + 1])); auto out_g_t = out_g->Slice(i, i + 1); - int64_t h = static_cast(lod[0][i + 1] - lod[0][i]); + int64_t h = static_cast(lod[i + 1] - lod[i]); auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); Eigen::DSizes bcast(h, 1); - in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); + + switch (strategy) { + case AVERAGE: + in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); + break; + case SUM: + in_g_e.device(place) = (out_g_e).broadcast(bcast); + break; + default: + PADDLE_THROW("unsupported pooling strategy"); + } } } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 71705cedf2603cd5027321657c2237a5b46ca5f5..581c5145a5bf0181a8b53c3db9170b4413cf5138 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { const int batch_size = logits->dims()[0]; if (context.Attr("softLabel")) { - //(TODO caoying) the forward implementation can be further optimized. + // (TODO caoying) the forward implementation can be further optimized. // Current implementation is exactly cross entropy after softmax. auto prob = EigenMatrix::From(*softmax); auto lbl_mat = EigenMatrix::From(*labels); diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu index afe4d149c53819c45e20353bc9d16393f3f61e0f..53fe505b77bfac8a33803f082f8e935d3ed403b6 100644 --- a/paddle/operators/top_k_op.cu +++ b/paddle/operators/top_k_op.cu @@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel { // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. - // TODO(typhoonzero): launch kernel on specified stream. // TODO(typhoonzero): refine this kernel. dim3 threads(256, 1); dim3 grid(input_height, 1); - KeMatrixTopK<<>>( - output_data, output->dims()[1], indices_data, input_data, input_width, - input_width, int(k)); + KeMatrixTopK<<< + grid, threads, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>(output_data, output->dims()[1], + indices_data, input_data, + input_width, input_width, int(k)); } }; diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index caa78acd98ea4b35fc69643689cfce23026275e0..895e8d6a63d1fad0ee7a6f5647402435d418b2f1 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "ParameterOptimizer.h" +#include "ParameterUpdateFunctions.h" #include "Regularizer.h" namespace paddle { @@ -37,6 +38,15 @@ public: real torch_learningRate = optConfig_.learning_method() == "torch_momentum" ? 1.0 - paraConfig.momentum() : 1.0; +#ifdef PADDLE_USE_MKLDNN + sgdUpdate(learningRate_ * paraConfig.learning_rate() * + (firstTime_ ? 1.0 : torch_learningRate), + paraConfig.momentum(), + applyDecay_ ? paraConfig.decay_rate() : 0, + vecs[PARAMETER_VALUE].get(), + vecs[PARAMETER_GRADIENT].get(), + vecs[PARAMETER_MOMENTUM].get()); +#else vecs[PARAMETER_VALUE]->sgdUpdate( *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM], @@ -44,6 +54,7 @@ public: (firstTime_ ? 1.0 : torch_learningRate), paraConfig.momentum(), applyDecay_ ? paraConfig.decay_rate() : 0); +#endif } virtual void finishBatch() { firstTime_ = false; } }; diff --git a/paddle/parameter/ParameterUpdateFunctions.cpp b/paddle/parameter/ParameterUpdateFunctions.cpp index c8af7105c78dcbf9f625a348b7f38efcf278469e..8b3be062b654a52e667626199be8c8bb4a2a96d7 100644 --- a/paddle/parameter/ParameterUpdateFunctions.cpp +++ b/paddle/parameter/ParameterUpdateFunctions.cpp @@ -30,6 +30,9 @@ void sgdUpdateCpu(real learningRate, const real* grad, real* momentumVec) { decayRate *= learningRate; +#ifdef PADDLE_USE_MKLDNN +#pragma omp parallel for +#endif for (size_t i = 0; i < size; ++i) { momentumVec[i] = momentum * momentumVec[i] - learningRate * grad[i] - decayRate * value[i]; diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index a106592e454e21c46cd2f87f1bbf6694955d6e23..f6a39a8e26c301296aac0af7f4e8b2c6c97ece24 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -34,13 +34,14 @@ class DeviceContext { template DeviceType* get_eigen_device() const; + + virtual void Wait() const {} }; class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); explicit CPUDeviceContext(CPUPlace place); - virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext { virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ - void Wait() const; + void Wait() const override; /*! \brief Return place in the device context. */ Place GetPlace() const override; diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index ed2420b8740e583d307f6836a70fe7e1c780e28b..f0c825bd9b0bc41396b8fdb95f0b4337cbe3db02 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -36,7 +36,7 @@ int GetCurrentDeviceId(); //! Set the GPU device id for next execution. void SetDeviceId(int device_id); -//!Get the memory usage of current GPU device. +//! Get the memory usage of current GPU device. void GpuMemoryUsage(size_t &available, size_t &total); //! Get the maximum allocation size of current GPU device. diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fbe074188e5870de4b00fa4fff733035739974ea..25e290ffbb94354da3393ca0b769aff512d74a41 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -237,7 +237,13 @@ All parameter, weight, gradient are variables in Paddle. return Backward(forwardOp, no_grad_vars).release(); }) .def("infer_shape", &OperatorBase::InferShape) - .def("run", &OperatorBase::Run) + .def("run", + [](OperatorBase &self, + const Scope &scope, + const platform::DeviceContext &dev_ctx) { + self.Run(scope, dev_ctx); + dev_ctx.Wait(); + }) .def("type", [](const OperatorBase &op) -> std::string { return op.Type(); }) .def("outputs", diff --git a/paddle/trainer/tests/CMakeLists.txt b/paddle/trainer/tests/CMakeLists.txt index f01ad4142d4fe7c7f7d7aac60d967ea114b93e56..066837ca959e46dbe3b39c661aa1bab11cbf2734 100644 --- a/paddle/trainer/tests/CMakeLists.txt +++ b/paddle/trainer/tests/CMakeLists.txt @@ -37,6 +37,19 @@ add_test(NAME test_CompareTwoNets --config_file_a=trainer/tests/sample_trainer_config_qb_rnn.conf --config_file_b=trainer/tests/sample_trainer_config_rnn.conf WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +################ test_CompareMKLDNNandCPU ###################### +if(WITH_MKLDNN) + add_unittest_without_exec(test_CompareMKLDNNandCPU + test_CompareTwoNets.cpp) + add_test(NAME test_CompareMKLDNNandCPU + COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python/ + ${CMAKE_CURRENT_BINARY_DIR}/test_CompareMKLDNNandCPU + --config_file_a=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_a=True + --config_file_b=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_b=False + --use_gpu=False + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +endif() + ############### test_CompareTwoOpts ################### add_unittest_without_exec(test_CompareTwoOpts test_CompareTwoOpts.cpp) diff --git a/paddle/trainer/tests/sample_trainer_config_simple_net.conf b/paddle/trainer/tests/sample_trainer_config_simple_net.conf new file mode 100644 index 0000000000000000000000000000000000000000..77f78161535c49da4ef7fc1563cff58c021aecef --- /dev/null +++ b/paddle/trainer/tests/sample_trainer_config_simple_net.conf @@ -0,0 +1,63 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +################################### Data Configuration ################################### +TrainData(ProtoData(files = "trainer/tests/mnist.list")) +################################### Algorithm Configuration ################################### +settings(batch_size = 1000, + learning_method = MomentumOptimizer(momentum=0.5, sparse=False)) +################################### Network Configuration ################################### +data = data_layer(name ="input", size=784) + +tmp = img_conv_layer(input=data, + num_channels=1, + filter_size=3, + num_filters=32, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = img_pool_layer(input=tmp, + pool_size=3, + stride=2, + padding=1, + pool_type=AvgPooling()) + +tmp = img_conv_layer(input=tmp, + filter_size=3, + num_filters=64, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = img_pool_layer(input=tmp, + pool_size=3, + stride=2, + padding=1, + pool_type=MaxPooling()) + +tmp = fc_layer(input=tmp, size=64, + bias_attr=True, + act=ReluActivation()) + +output = fc_layer(input=tmp, size=10, + bias_attr=True, + act=SoftmaxActivation()) + +lbl = data_layer(name ="label", size=10) + +cost = classification_cost(input=output, label=lbl) +outputs(cost) diff --git a/paddle/trainer/tests/test_CompareTwoNets.cpp b/paddle/trainer/tests/test_CompareTwoNets.cpp index 94f65e545d116c802fb4877dc14f07aaaf83a4fb..307645d2c3d21d954371fcedb5f95a2536a0183e 100644 --- a/paddle/trainer/tests/test_CompareTwoNets.cpp +++ b/paddle/trainer/tests/test_CompareTwoNets.cpp @@ -26,12 +26,15 @@ DECLARE_int32(gpu_id); DECLARE_bool(local); DECLARE_bool(use_gpu); +DECLARE_bool(use_mkldnn); DECLARE_string(config); DECLARE_string(nics); DEFINE_string(config_file_a, "", "config of one network to compare"); DEFINE_string(config_file_b, "", "config of another network to compare"); +DEFINE_bool(use_mkldnn_a, false, "whether to use mkldnn to run config_file_a"); +DEFINE_bool(use_mkldnn_b, false, "whether to use mkldnn to run config_file_b"); DEFINE_bool(need_high_accuracy, false, "whether need to run in double accuracy"); @@ -128,6 +131,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) { matA.getWidth()); } + if (FLAGS_use_mkldnn_a || FLAGS_use_mkldnn_b) { + // some format of mkldnn parameter is different with cpu + // test_MKLDNN will check the parameters + return; + } + vector& parametersA = comDataA.parameters; vector& parametersB = comDataB.parameters; @@ -167,10 +176,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) { TEST(Trainer, create) { ComData dataA; + FLAGS_use_mkldnn = FLAGS_use_mkldnn_a; calcGradient(dataA, FLAGS_config_file_a); LOG(INFO) << "\n\nforwardBackward of Network A is finished\n\n"; ComData dataB; + FLAGS_use_mkldnn = FLAGS_use_mkldnn_b; calcGradient(dataB, FLAGS_config_file_b); LOG(INFO) << "\n\nforwardBackward of the Network B is finished\n\n"; diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index c97e6c0a36774caaa4fd8f8130220849975451a0..74025d2a7bb68f87afd24bb4b70ec425ba0dcb64 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -921,7 +921,7 @@ def data_layer(name, size, depth=None, height=None, width=None, data = data_layer(name="input", size=1000) - :param name: The name of this layer. It is optional. + :param name: The name of this layer. :type name: basestring :param size: Size of this data layer. :type size: int @@ -3668,6 +3668,7 @@ def gru_step_naive_layer(input, :param param_attr: :param layer_attr: :return: + :rtype: LayerOutput """ if input.size % 3 != 0: raise ValueError("GruStep input size must be divided by 3") diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index f10db783225c07be9ffde25267fdfe096e97ecac..1de514dff487158e0823fd628d9b3b50f36fdd9b 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -4,22 +4,24 @@ from op_test import OpTest class TestCrossEntropyOp1(OpTest): - """Test standard cross-entropy, with index representation of labels. + """Test cross-entropy with discrete one-hot labels. """ def setUp(self): self.op_type = "cross_entropy" batch_size = 30 class_num = 10 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") cross_entropy = np.asmatrix( [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], dtype="float32") + self.inputs = {"X": X, "Label": label} self.outputs = {"Y": cross_entropy} - self.attrs = {'soft_label': False} + self.attrs = {"softLabel": False} def test_check_output(self): self.check_output() @@ -29,13 +31,14 @@ class TestCrossEntropyOp1(OpTest): class TestCrossEntropyOp2(OpTest): - """Test soft-label cross-entropy, with vecterized soft labels. + """Test cross-entropy with vectorized soft labels. """ def setUp(self): self.op_type = "cross_entropy" - batch_size = 10 - class_num = 5 + batch_size = 5 + class_num = 37 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label = np.random.uniform(0.1, 1.0, @@ -43,46 +46,49 @@ class TestCrossEntropyOp2(OpTest): label /= label.sum(axis=1, keepdims=True) cross_entropy = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") - self.inputs = {'X': X, 'Label': label} - self.outputs = {'Y': cross_entropy} - self.attrs = {'soft_label': True} + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"softLabel": True} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y", max_relative_error=0.05) class TestCrossEntropyOp3(OpTest): - """Test one-hot cross-entropy, with vecterized one-hot representation of - labels. + """Test cross-entropy with vectorized one-hot representation of labels. """ def setUp(self): self.op_type = "cross_entropy" - batch_size = 30 - class_num = 10 + batch_size = 5 + class_num = 17 + X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") label_index = np.random.randint( 0, class_num, (batch_size), dtype="int32") label = np.zeros(X.shape) label[np.arange(batch_size), label_index] = 1 + cross_entropy = np.asmatrix( [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], dtype="float32") cross_entropy2 = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") - self.inputs = {'X': X, 'Label': label} - self.outputs = {'Y': cross_entropy} - self.attrs = {'soft_label': True} + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"softLabel": True} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y", max_relative_error=0.05) if __name__ == "__main__": diff --git a/python/paddle/v2/framework/tests/test_lstm_unit_op.py b/python/paddle/v2/framework/tests/test_lstm_unit_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce65bfc31d9fa2d3988759a197e2f497b8161b1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lstm_unit_op.py @@ -0,0 +1,38 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sigmoid_np(x): + return 1. / (1. + np.exp(-x)) + + +def tanh_np(x): + return 2 * sigmoid_np(2. * x) - 1. + + +class LstmUnitTest(OpTest): + def setUp(self): + self.op_type = "lstm_unit" + x_np = np.random.normal(size=(5, 16)).astype("float32") + c_np = np.random.normal(size=(5, 4)).astype("float32") + i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1) + forget_bias_np = 0. + self.attrs = {'forget_bias': 0.} + + new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np( + i_np) * tanh_np(j_np) + new_h = tanh_np(new_c) * sigmoid_np(o_np) + + self.inputs = {'X': x_np, 'C_prev': c_np} + self.outputs = {'C': new_c, 'H': new_h} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b3881cde24c7fb96c3d7f9411352bc62d55077 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_multiplex_op.py @@ -0,0 +1,43 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestMultiplexOp(OpTest): + def setUp(self): + self.op_type = "multiplex" + rows = 3 + index = np.array([3, 1, 0]) + ins1 = np.random.random((rows, 10)).astype("float32") + ins2 = np.random.random((rows, 10)).astype("float32") + ins3 = np.random.random((rows, 10)).astype("float32") + ins4 = np.random.random((rows, 10)).astype("float32") + self.inputs = { + 'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3), + ('x4', ins4)] + } + # multiplex output + output = np.zeros_like(ins1) + for i in range(0, rows): + k = index[i] + 1 + output[i] = self.inputs['X'][k][1][i] + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out') + + def test_check_grad_ignore_x1(self): + self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1')) + + def test_check_grad_ignore_x1_x2(self): + self.check_grad(['x3', 'x4'], 'Out', no_grad_set=set(['x1', 'x2'])) + + def test_check_grad_ignore_x3(self): + self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3')) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py index 2b6b7db36808a4b68c55328a1eb9ac212c18b678..676fd9f7c555fd5c8544e760345ab954cd137dc5 100644 --- a/python/paddle/v2/framework/tests/test_prelu_op.py +++ b/python/paddle/v2/framework/tests/test_prelu_op.py @@ -7,6 +7,14 @@ class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" x_np = np.random.normal(size=(10, 10)).astype("float32") + + for pos, val in np.ndenumerate(x_np): + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + while abs(val) < 1e-3: + x_np[pos] = np.random.normal() + val = x_np[pos] + x_np_sign = np.sign(x_np) x_np = x_np_sign * np.maximum(x_np, .005) alpha_np = np.array([.1]) diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py index cf864936af6361da1f16df3cfb759b468214b970..211086e5f4de32b996f0fa27c2eb52670c2b1e11 100644 --- a/python/paddle/v2/framework/tests/test_seq_pool.py +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -3,20 +3,37 @@ import numpy as np from op_test import OpTest -class TestSeqAvgPool1D(OpTest): - def setUp(self): - self.op_type = 'sequence_avg_pool' +class SeqPoolType(OpTest): + AVERAGE = 0 + SUM = 1 + SQRT = 2 + MAX = 3 + LAST = 4 + FIRST = 5 + + +class TestSeqAvgPool(OpTest): + def set_data(self): + self.op_type = 'sequence_pool' # one level, batch size is 4 x = np.random.uniform(0.1, 1, [11, 23]).astype('float32') lod = [[0, 4, 5, 8, 11]] + self.inputs = {'X': (x, lod)} out = np.zeros((4, 23)).astype('float32') + self.outputs = {'Out': out} + + def compute(self): + self.attrs = {'strategy': SeqPoolType.AVERAGE} + x, lod = self.inputs['X'] + out = self.outputs['Out'] for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x.mean(axis=0) - self.inputs = {'X': (x, lod)} - self.outputs = {'Out': out} + def setUp(self): + self.set_data() + self.compute() def test_check_output(self): self.check_output() @@ -25,26 +42,44 @@ class TestSeqAvgPool1D(OpTest): self.check_grad(["X"], "Out") -class TestSeqAvgPool2D(OpTest): - def setUp(self): - self.op_type = 'sequence_avg_pool' +class TestSeqAvgPool2D(TestSeqAvgPool): + def set_data(self): + self.op_type = 'sequence_pool' # one level, batch size is 4 x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32') lod = [[0, 4, 5, 8, 13]] + self.inputs = {'X': (x, lod)} out = np.zeros((4, 3, 17)).astype('float32') + self.outputs = {'Out': out} + + def compute(self): + self.attrs = {'strategy': SeqPoolType.AVERAGE} + x, lod = self.inputs['X'] + out = self.outputs['Out'] for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) - self.inputs = {'X': (x, lod)} - self.outputs = {'Out': out} - def test_check_output(self): - self.check_output() +class TestSeqSumPool(TestSeqAvgPool): + def compute(self): + self.attrs = {'strategy': SeqPoolType.SUM} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + out[i] = sub_x.sum(axis=0) - def test_check_grad(self): - self.check_grad(["X"], "Out") + +class TestSeqSumPool2D(TestSeqAvgPool2D): + def compute(self): + self.attrs = {'strategy': SeqPoolType.SUM} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + out[i] = np.reshape(sub_x.sum(axis=0), (3, 17)) if __name__ == '__main__':