diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 786f224608f7d41c438411de0e09fedbcf2264b8..8b29227cfab2a36d5b9f6d17b837b33da8d2a92e 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -18,12 +18,13 @@ import sys import time import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid as fluid -import paddle.v2.fluid.core as core -import paddle.v2.fluid.profiler as profiler +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.profiler as profiler import argparse import functools import os +from paddle.fluid import debuger def str2bool(v): @@ -182,28 +183,27 @@ def main(): start_time = time.time() num_samples = 0 train_pass_acc.reset() - with profiler.profiler("CPU", 'total') as prof: - for batch_id, data in enumerate(train_reader()): - ts = time.time() - img_data = np.array( - map(lambda x: x[0].reshape(data_shape), data)).astype( - "float32") - y_data = np.array(map(lambda x: x[1], data)).astype("int64") - y_data = y_data.reshape([-1, 1]) - - loss, acc, b_size = exe.run( - trainer_prog, - feed={"pixel": img_data, - "label": y_data}, - fetch_list=[avg_cost, batch_acc, batch_size]) - iters += 1 - num_samples += len(data) - train_pass_acc.add(value=acc, weight=b_size) - print( - "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" - % (pass_id, iters, loss, acc, - len(data) / (time.time() - ts)) - ) # The accuracy is the accumulation of batches, but not the current batch. + for batch_id, data in enumerate(train_reader()): + ts = time.time() + img_data = np.array( + map(lambda x: x[0].reshape(data_shape), data)).astype( + "float32") + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = y_data.reshape([-1, 1]) + + loss, acc, b_size = exe.run( + trainer_prog, + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[avg_cost, batch_acc, batch_size]) + iters += 1 + num_samples += len(data) + train_pass_acc.add(value=acc, weight=b_size) + print( + "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" + % (pass_id, iters, loss, acc, + len(data) / (time.time() - ts)) + ) # The accuracy is the accumulation of batches, but not the current batch. 
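+            # NOTE: the feed preparation above relies on Python 2 semantics of
+            # map() returning a list; on Python 3, np.array() would wrap the
+            # lazy map object itself. A version-agnostic sketch of the same
+            # two conversions:
+            #   img_data = np.array(
+            #       [x[0].reshape(data_shape) for x in data]).astype("float32")
+            #   y_data = np.array([x[1] for x in data]).astype("int64")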
 pass_elapsed = time.time() - start_time
 pass_train_acc = train_pass_acc.eval()
@@ -254,9 +254,7 @@ def main():
 pserver_prog = t.get_pserver_program(current_endpoint)
 pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
- print("starting server side startup")
 exe.run(pserver_startup)
- print("starting parameter server...")
 exe.run(pserver_prog)
 elif training_role == "TRAINER":
 # Parameter initialization
diff --git a/benchmark/cluster/vgg16/vgg16_tf.py b/benchmark/cluster/vgg16/vgg16_tf.py
index 996df0e314b867ea8de618dfd3977f490fbe8372..2d220478acae46566760209dbc012cff316946aa 100644
--- a/benchmark/cluster/vgg16/vgg16_tf.py
+++ b/benchmark/cluster/vgg16/vgg16_tf.py
@@ -292,14 +292,18 @@ def run_benchmark(cluster_spec, server):
 return np.mean(test_accs)
 
 config = tf.ConfigProto(
-        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
+        intra_op_parallelism_threads=1,
+        inter_op_parallelism_threads=1,
+        log_device_placement=True)
 config.gpu_options.allow_growth = True
 
 hooks = [tf.train.StopAtStepHook(last_step=1000000)]
 
 with tf.train.MonitoredTrainingSession(
-            master=server.target, is_chief=(args.task_index == 0),
-            hooks=hooks) as sess:
+            master=server.target,
+            is_chief=(args.task_index == 0),
+            hooks=hooks,
+            config=config) as sess:
 iters, num_samples, start_time = 0, 0, 0.0
 for pass_id in range(args.num_passes):
 # train
diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt
index da67701ec1af57df742dce105990cffa40f45d7c..a9b27933a5307aabeaf150aeb859e869197229f5 100644
--- a/doc/CMakeLists.txt
+++ b/doc/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(v2)
+add_subdirectory(fluid)
diff --git a/doc/fluid/CMakeLists.txt b/doc/fluid/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cc999f5a8d70a2239ea3b130e9da172d5f681c65
--- /dev/null
+++ b/doc/fluid/CMakeLists.txt
@@ -0,0 +1,49 @@
+if(NOT DEFINED SPHINX_THEME)
+  set(SPHINX_THEME default)
+endif()
+
+if(NOT DEFINED SPHINX_THEME_DIR)
+  set(SPHINX_THEME_DIR)
+endif()
+
+# configured documentation tools and intermediate build results
+set(BINARY_BUILD_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_build")
+
+# Sphinx cache with pickled ReST documents
+set(SPHINX_CACHE_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_doctrees")
+
+# HTML output directory
+set(SPHINX_HTML_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/html")
+
+configure_file(
+  "${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.en.in"
+  "${BINARY_BUILD_DIR_EN}/conf.py"
+  @ONLY)
+
+sphinx_add_target(paddle_fluid_docs
+  html
+  ${BINARY_BUILD_DIR_EN}
+  ${SPHINX_CACHE_DIR_EN}
+  ${CMAKE_CURRENT_SOURCE_DIR}
+  ${SPHINX_HTML_DIR_EN})
+
+# configured documentation tools and intermediate build results
+set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build")
+
+# Sphinx cache with pickled ReST documents
+set(SPHINX_CACHE_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_doctrees")
+
+# HTML output directory
+set(SPHINX_HTML_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/html")
+
+configure_file(
+  "${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.cn.in"
+  "${BINARY_BUILD_DIR_CN}/conf.py"
+  @ONLY)
+
+sphinx_add_target(paddle_fluid_docs_cn
+  html
+  ${BINARY_BUILD_DIR_CN}
+  ${SPHINX_CACHE_DIR_CN}
+  ${CMAKE_CURRENT_SOURCE_DIR}
+  ${SPHINX_HTML_DIR_CN})
diff --git a/doc/fluid/build_and_install/index_cn.rst b/doc/fluid/build_and_install/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..9276236f9fd511bde3570a8c88b437119911d60a
--- /dev/null
+++ b/doc/fluid/build_and_install/index_cn.rst
@@ -0,0 +1,2 @@
+Installation and Usage
+----------------------
diff --git a/doc/fluid/build_and_install/index_en.rst b/doc/fluid/build_and_install/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..cc1e61a58a026a0f5c3b106875a8a86dc9cba613
--- /dev/null
+++ b/doc/fluid/build_and_install/index_en.rst
@@ -0,0 +1,2 @@
+Build and Install
+-----------------
diff --git a/doc/fluid/design/index_cn.rst b/doc/fluid/design/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..f1887be6901653d4263d711d78b626d2abfd45c9
--- /dev/null
+++ b/doc/fluid/design/index_cn.rst
@@ -0,0 +1,2 @@
+Design Philosophy
+-----------------
diff --git a/doc/fluid/design/index_en.rst b/doc/fluid/design/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..18a4b4122f6e3f0096676f34ffea8a80aa9b6696
--- /dev/null
+++ b/doc/fluid/design/index_en.rst
@@ -0,0 +1,2 @@
+Design
+------------
diff --git a/doc/fluid/dev/index_cn.rst b/doc/fluid/dev/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e1edf079fa0f85eb7f6709fd945fffae88625d01
--- /dev/null
+++ b/doc/fluid/dev/index_cn.rst
@@ -0,0 +1,2 @@
+Development Standards
+---------------------
diff --git a/doc/fluid/dev/index_en.rst b/doc/fluid/dev/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..faf9dfcd315fddc4774c3717b41086fa6c6bf85a
--- /dev/null
+++ b/doc/fluid/dev/index_en.rst
@@ -0,0 +1,4 @@
+Development
+------------
+
+This is the Development page
diff --git a/doc/fluid/faq/index_cn.rst b/doc/fluid/faq/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..395c1109891b5a00eab6f0b44d855658def7fdd6
--- /dev/null
+++ b/doc/fluid/faq/index_cn.rst
@@ -0,0 +1,2 @@
+FAQ
+------------
diff --git a/doc/fluid/faq/index_en.rst b/doc/fluid/faq/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..395c1109891b5a00eab6f0b44d855658def7fdd6
--- /dev/null
+++ b/doc/fluid/faq/index_en.rst
@@ -0,0 +1,2 @@
+FAQ
+------------
diff --git a/doc/fluid/getstarted/index_cn.rst b/doc/fluid/getstarted/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c4d8525f23ee18cb7f41ab2f0d148fc1dcc852b2
--- /dev/null
+++ b/doc/fluid/getstarted/index_cn.rst
@@ -0,0 +1,4 @@
+Getting Started
+---------------
+
+Getting Started
diff --git a/doc/fluid/getstarted/index_en.rst b/doc/fluid/getstarted/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..a4efd05e2fd94ac0e2cbbc8603e6b0261b7e787f
--- /dev/null
+++ b/doc/fluid/getstarted/index_en.rst
@@ -0,0 +1,4 @@
+GET STARTED
+------------
+
+This is the Get Started page
diff --git a/doc/fluid/howto/cluster/fluid_cluster_train_cn.md b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b6f767869aaa800c122c8e7a06a1413e48e10e0
--- /dev/null
+++ b/doc/fluid/howto/cluster/fluid_cluster_train_cn.md
@@ -0,0 +1,145 @@
+# Fluid Distributed Training Guide
+This article explains how to configure and run distributed training with the PaddlePaddle Fluid API, and how to convert a single-machine training script into one that supports cluster training.
+
+## Prerequisites
+* An available cluster
+
+  A cluster with one or more compute nodes, where every node can run PaddlePaddle training jobs, has a unique IP address, and can communicate with all other nodes over the network.
+* PaddlePaddle Fluid with distribution support installed
+
+  The distributed build of PaddlePaddle must be installed on every compute node; machines that use GPUs or similar devices additionally need the corresponding drivers and CUDA libraries installed.
+
+  **Note:** the currently released PaddlePaddle packages do not support distributed training, so you need to rebuild from source. For how to compile and install, see the [build and install guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html).
+  WITH_DISTRIBUTE must be set to ON in the cmake command; an example cmake invocation:
+``` bash
+cmake .. -DWITH_DOC=OFF -DWITH_GPU=OFF -DWITH_DISTRIBUTE=ON -DWITH_SWIG_PY=ON -DWITH_PYTHON=ON
+```
+
+## Updating the training script
+Here we take the fit a line example from the first chapter of the [Deep Learning 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html) course and show how to convert the single-machine training script into one that supports cluster training.
+### Single-machine training script
+```python
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+
+x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+y_predict = fluid.layers.fc(input=x, size=1, act=None)
+y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+avg_cost = fluid.layers.mean(x=cost)
+
+sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+sgd_optimizer.minimize(avg_cost)
+
+BATCH_SIZE = 20
+
+train_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.uci_housing.train(), buf_size=500),
+    batch_size=BATCH_SIZE)
+
+place = fluid.CPUPlace()
+feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+exe = fluid.Executor(place)
+
+exe.run(fluid.default_startup_program())
+
+PASS_NUM = 100
+for pass_id in range(PASS_NUM):
+    fluid.io.save_persistables(exe, "./fit_a_line.model/")
+    fluid.io.load_persistables(exe, "./fit_a_line.model/")
+    for data in train_reader():
+        avg_loss_value, = exe.run(fluid.default_main_program(),
+                                  feed=feeder.feed(data),
+                                  fetch_list=[avg_cost])
+
+        if avg_loss_value[0] < 10.0:
+            exit(0)  # if avg cost less than 10.0, we think our code is good.
+exit(1)
+```
+
+We created a simple fully connected neural network program and ran it for 100 passes with Fluid's Executor; now we need to turn this single-machine program into a distributed one.
+### Introducing the Parameter Server
+In the non-distributed training script there is only one role, the Trainer, which handles both the regular computation and all parameter-related computation, saving, and optimization. In distributed training, several Trainer nodes run the same computation over different data, so a centralized node is needed to manage the storage and distribution of parameters. In PaddlePaddle, we call such a node the [Parameter Server](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/dist_train/parameter_server.md).
+
+**Therefore, in a distributed Fluid environment we have two roles to create: the Parameter Server and the Trainer.**
+
+### Distributed training
+Fluid provides a dedicated tool, the [Distributed Transpiler](https://github.com/PaddlePaddle/Paddle/blob/ba65d54d9d3b41cd3c5171b00f476d4e60133ddb/doc/fluid/design/dist_train/distributed_architecture.md#distributed-transpiler), to convert a single-machine training program into a distributed one. The idea behind the tool is to find the program's optimize operators and gradient parameters and split them into two parts connected by send/recv operators; the optimize operators and gradient parameters can be obtained from the return value of the optimizer's minimize function.
+```python
+optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+```
+Putting the Distributed Transpiler, optimize operators, and gradients together in one piece of code:
+```python
+... #define the program, cost, and create sgd optimizer
+
+optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) #get optimize OPs and gradient parameters
+
+t = fluid.DistributeTranspiler() # create the transpiler instance
+# slice the program into 2 pieces with optimizer_ops and gradient parameters list, as well as pserver_endpoints, which is a comma separated list of [IP:PORT] and number of trainers
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+
+... #create executor
+
+# in pserver, run this
+#current_endpoint here means current pserver IP:PORT you wish to run on
+pserver_prog = t.get_pserver_program(current_endpoint)
+pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+exe.run(pserver_startup)
+exe.run(pserver_prog)
+
+# in trainer, run this
+... # define data reader
+exe.run(fluid.default_startup_program())
+for pass_id in range(100):
+    for data in train_reader():
+        exe.run(t.get_trainer_program())
+```
+### Running the distributed training script
+Running a distributed job requires assigning the parameters described in the table below:
+
+| Parameter | Type | Description | Example |
+|:-------------|:------|:---------------------------------------|:-------------|
+| trainer_id | int | ID of the current trainer node, numbered from 0 to n-1, where n is the value of trainers | 0/1/2/3 |
+| pservers | str | list of parameter servers | 127.0.0.1:6710,127.0.0.1:6711 |
+| trainers | int | total number of trainer nodes, a number > 0 | 4 |
+| server_endpoint | str | IP:PORT of the server started on the current node | 127.0.0.1:8789 |
+| training_role | str | role of the node, TRAINER/PSERVER | PSERVER |
+
+**Note:** ```training_role``` distinguishes the role of the currently started service and is used inside the training program; users may define it however they need. The other parameters are required by the transpile function of fluid.DistributeTranspiler and must be defined before the function is called, for example:
+
+```python
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops,
+    params_grads,
+    trainer_id,
+    pservers=pserver,
+    trainers=trainers)
+if training_role == "PSERVER":
+    pserver_prog = t.get_pserver_program(server_endpoint)
+    pserver_startup = t.get_startup_program(server_endpoint, pserver_prog)
+```
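+
+In practice, these parameters are usually read from environment variables before
+```transpile``` is called. The exact variable names below are only an
+illustrative assumption (the demo in the next section uses names such as
+```TRAINING_ROLE```, ```TRAINERS``` and ```PADDLE_INIT_TRAINER_ID```); a minimal
+sketch:
+
+```python
+import os
+
+# Assumed environment variable names; adapt them to your launcher.
+training_role = os.getenv("TRAINING_ROLE", "TRAINER")
+pserver_endpoints = os.getenv("PADDLE_INIT_PSERVERS", "127.0.0.1:6174")
+trainers = int(os.getenv("TRAINERS", "2"))
+trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", "0"))
+```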
+
+### Demo
+The complete demo code lives in [book](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_fit_a_line.py) under Fluid's test directory.
+
+Step 1: change into the directory that holds the demo code:
+```bash
+cd /paddle/python/paddle/fluid/tests/book
+```
+
+Step 2: start the Parameter Server:
+```bash
+PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.2 TRAINERS=2 POD_IP=192.168.1.2 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=PSERVER python test_fit_a_line.py
+```
+After running the command, wait for the message ```Server listening on 192.168.1.2:6174 ```, which indicates that the Parameter Server has started successfully.
+
+Step 3: start the Trainer:
+```bash
+PADDLE_INIT_PORT=6174 PADDLE_INIT_PSERVERS=192.168.1.3 TRAINERS=2 POD_IP=192.168.1.3 PADDLE_INIT_TRAINER_ID=1 TRAINING_ROLE=TRAINER python test_fit_a_line.py
+```
+Since we configured two Trainers, a second Trainer needs to be started on another compute node.
+
+We have now launched a distributed training job with one Parameter Server and two Trainers.
diff --git a/doc/fluid/howto/index_cn.rst b/doc/fluid/howto/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..a92abad0c56a4fd821f9a6b9f4f5909504c8aaf1
--- /dev/null
+++ b/doc/fluid/howto/index_cn.rst
@@ -0,0 +1,2 @@
+Advanced Usage
+--------------
diff --git a/doc/fluid/howto/index_en.rst b/doc/fluid/howto/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..06036bdce554a96443ea1fa47c15f7670ea6089d
--- /dev/null
+++ b/doc/fluid/howto/index_en.rst
@@ -0,0 +1,4 @@
+HOW TO
+------------
+
+This is the How To page
diff --git a/doc/fluid/index_cn.rst b/doc/fluid/index_cn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..be3bed4393a7346d4f2a53e2c7409ee7165fb5b6
--- /dev/null
+++ b/doc/fluid/index_cn.rst
@@ -0,0 +1,12 @@
+ PaddlePaddle Fluid
+==========================
+
+.. toctree::
+  :maxdepth: 1
+
+  getstarted/index_cn.rst
+  design/index_cn.rst
+  build_and_install/index_cn.rst
+  howto/index_cn.rst
+  dev/index_cn.rst
+  faq/index_cn.rst
diff --git a/doc/fluid/index_en.rst b/doc/fluid/index_en.rst
new file mode 100644
index 0000000000000000000000000000000000000000..87c831420a57b4b9ce77ecf44f7f4d0feec833a6
--- /dev/null
+++ b/doc/fluid/index_en.rst
@@ -0,0 +1,12 @@
+ PaddlePaddle Fluid
+==========================
+
+.. toctree::
+  :maxdepth: 1
+
+  getstarted/index_en.rst
+  design/index_en.rst
+  build_and_install/index_en.rst
+  howto/index_en.rst
+  dev/index_en.rst
+  faq/index_en.rst
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index d74c47b981e51f12d99098818c71f3f6ec455d98..ec637658c03ad94624ee9a4f5def6a84387d293e 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -613,3 +613,14 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
             ops::grad_functor>);
 
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
+
+REGISTER_OP_CPU_KERNEL(relu,
+                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
+                                             ops::ReluFunctor<float>>,
+                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
+                                             ops::ReluFunctor<double>>);
+REGISTER_OP_CPU_KERNEL(
+    relu_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
+                                         ops::ReluGradFunctor<float>>,
+    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
+                              ops::ReluGradFunctor<double>>);
diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index b2633d017623c3a6a3bab2b416009d6d7c8fc1d4..7709a551dc155e1f3cd2a19a689999608f497beb 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #define EIGEN_USE_GPU
 #include "paddle/fluid/operators/activation_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
 
@@ -31,3 +32,16 @@ namespace ops = paddle::operators;
             ops::grad_functor>);
 
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
+
+REGISTER_OP_CUDA_KERNEL(
+    relu, ops::ActivationKernel<paddle::platform::CUDADeviceContext,
+                                ops::ReluFunctor<float>>,
+    ops::ActivationKernel<paddle::platform::CUDADeviceContext,
+                          ops::ReluFunctor<double>>,
+    ops::ActivationKernel<paddle::platform::CUDADeviceContext,
+                          ops::ReluFunctor<paddle::platform::float16>>);
+REGISTER_OP_CUDA_KERNEL(
+    relu_grad, ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
+                                         ops::ReluGradFunctor<float>>,
+    ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
+                              ops::ReluGradFunctor<double>>);
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 8f791a6ca81c13a92fd8adf0d1620203bd4cf7d6..b95e793586219b7c413d0c7adb835081874d9363 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -772,7 +772,6 @@ struct SwishGradFunctor : public BaseActivationFunctor {
   __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);           \
   __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);  \
   __macro(exp, ExpFunctor, ExpGradFunctor);                       \
-  __macro(relu, ReluFunctor, ReluGradFunctor);                    \
   __macro(tanh, TanhFunctor, TanhGradFunctor);                    \
   __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);  \
   __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                    \
diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c95077fcbdb6b6c0da31f30b795dbe4d7d4fe6fe
--- /dev/null
+++ b/paddle/fluid/operators/average_accumulates_op.cc
@@ -0,0 +1,216 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/average_accumulates_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <>
+void GetAccumulators<paddle::platform::CPUDeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t& num_updates_,
+    int64_t& num_accumulates_, int64_t& old_num_accumulates_) {
+  auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
+  auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
+  auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
+
+  old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
+  num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
+  num_updates_ = in_num_updates->data<int64_t>()[0];
+}
+
+template <>
+void SetAccumulators<paddle::platform::CPUDeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t num_updates_,
+    int64_t num_accumulates_, int64_t old_num_accumulates_) {
+  auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
+  auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
+  auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
+
+  out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
+  out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
+  out_num_updates->data<int64_t>()[0] = num_updates_;
+}
+
+class AverageAccumulatesOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("param"),
+        "Input (param) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_sum_1"),
+        "Input (sum_1) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_sum_2"),
+        "Input (sum_2) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_sum_3"),
+        "Input (sum_3) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_num_accumulates"),
+        "Input (in_num_accumulates) of average_accumulates op should "
+        "not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("in_old_num_accumulates"),
+                   "Input (old_num_accumulates) of average_accumulates op "
+                   "should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_num_updates"),
+        "Input (num_updates) of average_accumulates op should not be null.");
+
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_sum_1"),
+        "Output (sum_1) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_sum_2"),
+        "Output (sum_2) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_sum_3"),
+        "Output (sum_3) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("out_num_accumulates"),
+                   "Output (num_accumulates) of average_accumulates op should "
+                   "not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("out_old_num_accumulates"),
+                   "Output (old_num_accumulates) of average_accumulates op "
+                   "should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_num_updates"),
+        "Output (num_updates) of average_accumulates op should not be null.");
+
+    auto in_dim = ctx->GetInputDim("param");
+
+    ctx->SetOutputDim("out_sum_1", in_dim);
+    ctx->SetOutputDim("out_sum_2", in_dim);
+    ctx->SetOutputDim("out_sum_3", in_dim);
+    ctx->SetOutputDim("out_num_accumulates", {1});
+    ctx->SetOutputDim("out_old_num_accumulates", {1});
+    ctx->SetOutputDim("out_num_updates", {1});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("param")->type()),
+        ctx.GetPlace());
+  }
+};
+
+class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("param", "(Tensor), The parameter to be accumulated.");
+    AddInput("in_sum_1",
+             "(Tensor), A tensor used to store the parameter "
+             "sums with the same shape as input(param).");
+    AddInput("in_sum_2",
+             "(Tensor), An auxiliary tensor to help "
+             "accumulating sums of parameter values with the same shape as "
+             "input(param). It is used to avoid loss of precision due to too "
+             "many sums.");
+    AddInput("in_sum_3",
+             "(Tensor), An auxiliary tensor to help "
+             "accumulating sums of parameter values with the same shape as "
+             "input(param).");
+    AddInput("in_num_accumulates",
+             "(Tensor), The accumulating times of the current window with "
+             "shape [1].");
+    AddInput(
+        "in_old_num_accumulates",
+        "(Tensor), The accumulating times of the previous window with "
+        "shape [1].");
+    AddInput("in_num_updates",
+             "(Tensor), The total number of batches used by training "
+             "before this batch with shape [1].");
+
+    AddOutput("out_sum_1",
+              "(Tensor), A tensor used to store the "
+              "parameter sums with the same shape as input(param).");
+    AddOutput("out_sum_2",
+              "(Tensor), An auxiliary tensor to help "
+              "accumulating sums of parameter values with the same shape as "
+              "input(param). It is used to avoid loss of precision due to too "
+              "many sums.");
+    AddOutput("out_sum_3",
+              "(Tensor), An auxiliary tensor to help "
+              "accumulating sums of parameter values with the same shape as "
+              "input(param).");
+    AddOutput(
+        "out_num_accumulates",
+        "(Tensor), The accumulating times of the current window with "
+        "shape [1].");
+    AddOutput(
+        "out_old_num_accumulates",
+        "(Tensor), The accumulating times of the previous window with "
+        "shape [1].");
+    AddOutput(
+        "out_num_updates",
+        "(Tensor), The total number of batches used by training "
+        "before this batch with shape [1].");
+
+    AddAttr<float>("average_window",
+                   "(float, default 0) "
+                   "The rate of average window size relative to num_updates.")
+        .SetDefault(0);
+    AddAttr<int64_t>("max_average_window",
+                     "(int64_t) "
+                     "Maximum size of average window. The number of "
+                     "mini-batches in one pass is an appropriate value "
+                     "to set.");
+    AddAttr<int64_t>("min_average_window",
+                     "(int64_t, default 10000L) "
+                     "Minimum size of average window.")
+        .SetDefault(10000L);
+
+    AddComment(R"DOC(
+AverageAccumulates Operator.
+Accumulate the sum of parameter within a sliding window. The size of the sliding window is
+determined by 'average_window', 'max_average_window' and 'min_average_window'.
+Memory is shared by Input(in_sum_1) and Output(out_sum_1), which acts as an accumulator 'sum_1'.
+'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' are handled in the same way as 'sum_1'.
+
+All the accumulators are initialized to zero before training.
+
+For each mini-batch in training, the accumulators are updated as follows:
+  num_updates += 1
+  num_accumulates += 1
+  sum_1 += param
+  if num_updates % kMaxNumAccumulates == 0:
+    sum_2 += sum_1
+    sum_1 = 0
+  if num_accumulates >= min_average_window && num_accumulates >= min(max_average_window, num_updates * average_window):
+    sum_3 = sum_1 + sum_2
+    sum_1 = 0
+    sum_2 = 0
+    old_num_accumulates = num_accumulates
+    num_accumulates = 0
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(average_accumulates, ops::AverageAccumulatesOp,
+                  ops::AverageAccumulatesOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    average_accumulates,
+    ops::AverageAccumulatesKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AverageAccumulatesKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/average_accumulates_op.cu b/paddle/fluid/operators/average_accumulates_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..270c46984465e5ca62eaa8da3955ce7a3eaa0c57
--- /dev/null
+++ b/paddle/fluid/operators/average_accumulates_op.cu
@@ -0,0 +1,63 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/average_accumulates_op.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+template <>
+void GetAccumulators<paddle::platform::CUDADeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t& num_updates_,
+    int64_t& num_accumulates_, int64_t& old_num_accumulates_) {
+  auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
+  auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
+  auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
+  auto stream = ctx.cuda_device_context().stream();
+  memory::Copy(platform::CPUPlace(), &old_num_accumulates_,
+               platform::CUDAPlace(), in_old_num_accumulates->data<int64_t>(),
+               sizeof(int64_t), stream);
+  memory::Copy(platform::CPUPlace(), &num_accumulates_, platform::CUDAPlace(),
+               in_num_accumulates->data<int64_t>(), sizeof(int64_t), stream);
+  memory::Copy(platform::CPUPlace(), &num_updates_, platform::CUDAPlace(),
+               in_num_updates->data<int64_t>(), sizeof(int64_t), stream);
+}
+
+template <>
+void SetAccumulators<paddle::platform::CUDADeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t num_updates_,
+    int64_t num_accumulates_, int64_t old_num_accumulates_) {
+  auto stream = ctx.cuda_device_context().stream();
+  auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
+  auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
+  auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
+
+  memory::Copy(platform::CUDAPlace(), out_old_num_accumulates->data<int64_t>(),
+               platform::CPUPlace(), &old_num_accumulates_, sizeof(int64_t),
+               stream);
+  memory::Copy(platform::CUDAPlace(), out_num_accumulates->data<int64_t>(),
+               platform::CPUPlace(), &num_accumulates_, sizeof(int64_t),
+               stream);
+  memory::Copy(platform::CUDAPlace(), out_num_updates->data<int64_t>(),
+               platform::CPUPlace(), &num_updates_, sizeof(int64_t), stream);
+}
+
+}  // namespace operators
+}  // namespace
paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + average_accumulates, + ops::AverageAccumulatesKernel, + ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f858109d1428dc67d94c253e5a39818eb2d4560d --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenVector = framework::EigenVector; + +template +void GetAccumulators(const framework::ExecutionContext& ctx, + int64_t& num_updates, int64_t& num_accumulates, + int64_t& old_num_accumulates); + +template +void SetAccumulators(const framework::ExecutionContext& ctx, + int64_t num_updates, int64_t num_accumulates, + int64_t old_num_accumulates); + +template +class AverageAccumulatesKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // It is used to avoid loss of precision + static const int64_t kMaxNumAccumulates = 16384; + // Get accumulators from input + int64_t num_updates = 0; + int64_t num_accumulates = 0; + int64_t old_num_accumulates = 0; + GetAccumulators(ctx, num_updates, num_accumulates, + old_num_accumulates); + + // Get attrs + float average_window = ctx.Attr("average_window"); + int64_t max_average_window = ctx.Attr("max_average_window"); + int64_t min_average_window = ctx.Attr("min_average_window"); + min_average_window = + std::min(min_average_window, max_average_window); + + // Get inputs + auto* param = ctx.Input("param"); + auto* in_sum_1 = ctx.Input("in_sum_1"); + auto* in_sum_2 = ctx.Input("in_sum_2"); + auto* in_sum_3 = ctx.Input("in_sum_3"); + auto param_tensor = EigenVector::Flatten(*param); + auto in_sum_1_tensor = EigenVector::Flatten(*in_sum_1); + auto in_sum_2_tensor = EigenVector::Flatten(*in_sum_2); + auto in_sum_3_tensor = EigenVector::Flatten(*in_sum_3); + + // Get outputs + auto* out_sum_1 = ctx.Output("out_sum_1"); + auto* out_sum_2 = ctx.Output("out_sum_2"); + auto* out_sum_3 = ctx.Output("out_sum_3"); + auto out_sum_1_tensor = EigenVector::Flatten(*out_sum_1); + auto out_sum_2_tensor = EigenVector::Flatten(*out_sum_2); + auto out_sum_3_tensor = EigenVector::Flatten(*out_sum_3); + + // Compute + auto& place = *ctx.template device_context().eigen_device(); + math::SetConstant constant_functor; + ++num_updates; + ++num_accumulates; + out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor; + out_sum_2_tensor.device(place) = in_sum_2_tensor; + out_sum_3_tensor.device(place) = in_sum_3_tensor; + if (num_updates % kMaxNumAccumulates == 0) { + // Move the sum to a 
different buffer to avoid loss of precision due to + // too many sums. + out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor; + constant_functor(ctx.template device_context(), out_sum_1, + 0.0); + } + if (num_accumulates >= min_average_window && + num_accumulates >= std::min(max_average_window, + num_updates * average_window)) { + // Now the average window is too long, discard the old sum. + out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor; + constant_functor(ctx.template device_context(), out_sum_1, + 0.0); + constant_functor(ctx.template device_context(), out_sum_2, + 0.0); + old_num_accumulates = num_accumulates; + num_accumulates = 0; + } + + // Set accumulators to output + SetAccumulators(ctx, num_updates, num_accumulates, + old_num_accumulates); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index ec315695a68befc2e3de798fdb3fa146a903aaff..6da3a24dc89a85fe432b6350d3af7b0e84337c9d 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -78,7 +78,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { for (int64_t i = 0; i < batch_size; ++i) { PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); int64_t index = i * class_num + label_data[i]; - dx_data[index] = -dy_data[i] / x_data[index]; + dx_data[index] = math::TolerableValue()(-dy_data[i] / x_data[index]); } } } diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 94395ccfbcbd74ee40552a5c70dc8b8063a5f851..2b19f0448955d2d7582f23ac133c14ffdf5c9e49 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,6 +1,8 @@ if(WITH_DISTRIBUTE) - grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc + grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) + cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr + cares zlib protobuf sendrecvop_grpc) endif() diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/detail/bytebuffer_stream.h index 099deb12d0e436427c147ab9b1eb553b712e14fb..0cbe514d0498b4775c5cb4ff3e7ff4f968da4180 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.h +++ b/paddle/fluid/operators/detail/bytebuffer_stream.h @@ -23,9 +23,107 @@ limitations under the License. 
*/ #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" +namespace grpc { +// A ZeroCopyInputStream that reads from grpc_byte_buffer +class GrpcBufferReader final + : public ::google::protobuf::io::ZeroCopyInputStream { + typedef void (CoreCodegenInterface::*OldReaderInitAPI)( + grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); + typedef int (CoreCodegenInterface::*NewReaderInitAPI)( + grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); + void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader, + grpc_byte_buffer* buffer) { + (g_core_codegen_interface->*ptr)(reader, buffer); + } + void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader, + grpc_byte_buffer* buffer) { + int result = (g_core_codegen_interface->*ptr)(reader, buffer); + (void)result; + } + + public: + explicit GrpcBufferReader(grpc_byte_buffer* buffer) + : byte_count_(0), backup_count_(0) { + ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_, + buffer); + } + ~GrpcBufferReader() override { + g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_); + } + + bool Next(const void** data, int* size) override { + if (backup_count_ > 0) { + *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) - + backup_count_; + GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX); + *size = (int)backup_count_; + backup_count_ = 0; + return true; + } + if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_, + &slice_)) { + return false; + } + g_core_codegen_interface->grpc_slice_unref(slice_); + *data = GRPC_SLICE_START_PTR(slice_); + // On win x64, int is only 32bit + GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); + byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_); + return true; + } + + void BackUp(int count) override { backup_count_ = count; } + + bool Skip(int count) override { + const void* data; + int size; + while (Next(&data, &size)) { + if (size >= count) { + BackUp(size - count); + return true; + } + // size < count; + count -= size; + } + // error or we have too large count; + return false; + } + + ::google::protobuf::int64 ByteCount() const override { + return byte_count_ - backup_count_; + } + + private: + int64_t byte_count_; + int64_t backup_count_; + grpc_byte_buffer_reader reader_; + grpc_slice slice_; +}; + +}; // namespace grpc + namespace paddle { namespace operators { namespace detail { +// Source provides a way for a particular RPC implementation to provide +// received data to ParseFrom. +class Source { + public: + virtual ~Source() {} + + // Return the stream that contains the data to be parsed. + // Note that this method might be invoked more than once if + // ParseFrom needs to fall back to a more expensive parsing method. + // Every call must return a stream pointing at the beginning of + // the serialized RecvTensorResponse. + // + // Note that a subsequent call to contents() invalidates previous + // results of contents(). + // + // Ownership of the returned stream is retained by the Source and + // should not be deleted by the caller. + virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0; +}; // A ZeroCopyInputStream that reads from a grpc::ByteBuffer. 
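+// GrpcByteBufferSource is wrapped by GrpcByteBufferSourceWrapper (below) so
+// that VariableResponse::Parse can consume an already-initialized reader
+// through the Source interface.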
class GrpcByteBufferSource @@ -46,6 +144,42 @@ class GrpcByteBufferSource ::google::protobuf::int64 byte_count_; }; +class GrpcByteBufferSourceWrapper : public Source { + public: + GrpcByteBufferSourceWrapper(GrpcByteBufferSource* source) : source_(source) {} + virtual ::google::protobuf::io::ZeroCopyInputStream* contents() override { + return source_; + } + + private: + GrpcByteBufferSource* source_; +}; + +class GrpcByteSource : public Source { + public: + explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {} + ~GrpcByteSource() override { DeleteStream(); } + + typedef ::grpc::GrpcBufferReader Reader; + + ::google::protobuf::io::ZeroCopyInputStream* contents() override { + DeleteStream(); + stream_ = new (&space_) Reader(buffer_); + return stream_; + } + + private: + void DeleteStream() { + if (stream_) { + stream_->~Reader(); + } + } + + grpc_byte_buffer* buffer_; // Not owned + Reader* stream_ = nullptr; // Points into space_ if non-nullptr + char space_[sizeof(Reader)]; +}; + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ddeeebec58e02f1686fd2e3d3e5ac1a4c4fd3c59..eb19685aa6cb73862b9e31afbf9c5138659b1b13 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "grpc_client.h" +#include #include "paddle/fluid/framework/threadpool.h" + namespace paddle { namespace operators { namespace detail { @@ -31,8 +33,9 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { auto* var = p_scope->FindVar(var_name_val); - sendrecv::VariableMessage req; - SerializeToMessage(var_name_val, var, *p_ctx, &req); + + ::grpc::ByteBuffer req; + SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); // varhandle VarHandle var_h; @@ -46,8 +49,11 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, s->Prepare(var_h, time_out); s->response_call_back_ = NULL; - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); + auto call = std::move(s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, + &cq_)); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, (void*)s); }); req_count_++; @@ -56,9 +62,19 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, } void ProcGetResponse(const VarHandle& var_h, - const sendrecv::VariableMessage& ret_msg) { - auto* outvar = var_h.scope->FindVar(var_h.name); - DeserializeFromMessage(ret_msg, *var_h.ctx, outvar); + // const sendrecv::VariableMessage& ret_msg) { + const ::grpc::ByteBuffer& ret_msg) { + framework::Variable* outvar = NULL; + DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, outvar); +} + +template +void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { + ::grpc::Slice slice(proto.ByteSizeLong()); + proto.SerializeWithCachedSizesToArray( + const_cast(reinterpret_cast(slice.begin()))); + ::grpc::ByteBuffer tmp(&slice, 1); + result->Swap(&tmp); } bool RPCClient::AsyncGetVariable(const std::string& ep, @@ -88,8 +104,13 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, 
&cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); + + auto call = std::move(s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_)); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, (void*)s); }); req_count_++; diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index f520367dd981288416631fdad15241fb5d811d07..8216ac52fbbb3dcd2f30957cde58a850a77b08d6 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -25,6 +25,11 @@ limitations under the License. */ #include #include +#include +#include +#include +#include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -49,15 +54,11 @@ struct VarHandle { } }; -void ProcGetResponse(const VarHandle& var_h, - const sendrecv::VariableMessage& msg); +void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { - stub_ = sendrecv::SendRecvService::NewStub(ch); - context_ = NULL; - } + explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; } virtual ~BaseProcessor() {} @@ -82,19 +83,18 @@ class BaseProcessor { virtual void Process() = 0; - std::unique_ptr stub_; std::unique_ptr context_; grpc::Status status_; VarHandle var_h_; }; -typedef std::function +typedef std::function RequestSendCallBack; class SendProcessor : public BaseProcessor { public: explicit SendProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch), stub_g_(ch) {} virtual ~SendProcessor() {} @@ -104,17 +104,18 @@ class SendProcessor : public BaseProcessor { } } - sendrecv::VoidMessage reply_; + ::grpc::GenericStub stub_g_; + ::grpc::ByteBuffer reply_; RequestSendCallBack response_call_back_ = NULL; }; -typedef std::function +typedef std::function RequestGetCallBack; class GetProcessor : public BaseProcessor { public: explicit GetProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch), stub_g_(ch) {} virtual ~GetProcessor() {} @@ -124,30 +125,37 @@ class GetProcessor : public BaseProcessor { } } - sendrecv::VariableMessage reply_; + ::grpc::ByteBuffer reply_; + ::grpc::GenericStub stub_g_; RequestGetCallBack response_call_back_ = ProcGetResponse; }; class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } virtual ~BatchBarrierProcessor() {} virtual void Process() {} sendrecv::VoidMessage reply_; + std::unique_ptr stub_; }; class FetchBarrierProcessor : public BaseProcessor { public: explicit FetchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) {} + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } virtual ~FetchBarrierProcessor() {} virtual void Process() {} sendrecv::VariableMessage reply_; + std::unique_ptr stub_; }; class RPCClient { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 8fff430cc4890925e4edba2fadb8eb7fc647d181..9691d1e86b111def5b82e022dd01795aaf5c7b0d 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -14,7 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/detail/grpc_server.h" -using grpc::ServerAsyncResponseWriter; +using ::grpc::ServerAsyncResponseWriter; namespace paddle { namespace operators { @@ -26,9 +26,10 @@ enum CallStatus { PROCESS = 0, FINISH }; // https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server class RequestBase { public: - explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq) - : service_(service), cq_(cq), status_(PROCESS) { + explicit RequestBase(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + const platform::DeviceContext* dev_ctx) + : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { PADDLE_ENFORCE(cq_); } virtual ~RequestBase() {} @@ -42,55 +43,58 @@ class RequestBase { } protected: - grpc::ServerContext ctx_; - sendrecv::SendRecvService::AsyncService* service_; - grpc::ServerCompletionQueue* cq_; + ::grpc::ServerContext ctx_; + GrpcService::AsyncService* service_; + ::grpc::ServerCompletionQueue* cq_; CallStatus status_; + const platform::DeviceContext* dev_ctx_; }; -typedef std::pair MessageWithName; - class RequestSend final : public RequestBase { public: - explicit RequestSend(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq, - SimpleBlockQueue* queue) - : RequestBase(service, cq), queue_(queue), responder_(&ctx_) { - service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_, - this); + explicit RequestSend(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, ReceivedQueue* queue, + const platform::DeviceContext* dev_ctx) + : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { + request_.reset(new VariableResponse(scope, dev_ctx_)); + int method_id = static_cast(detail::GrpcMethod::kSendVariable); + service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, + cq_, cq_, this); } virtual ~RequestSend() {} - virtual std::string GetReqName() { return request_.varname(); } + virtual std::string GetReqName() { return request_->Varname(); } virtual void Process() { - MessageWithName msg_with_name = - std::make_pair(request_.varname(), std::move(request_)); - queue_->Push(std::move(msg_with_name)); - responder_.Finish(reply_, grpc::Status::OK, this); + queue_->Push(std::make_pair(request_->Varname(), request_)); + + sendrecv::VoidMessage reply; + responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; } protected: - sendrecv::VariableMessage request_; - sendrecv::VoidMessage reply_; - SimpleBlockQueue* queue_; + std::shared_ptr request_; + ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; }; class RequestGet final : public RequestBase { public: - explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq, framework::Scope* scope, + explicit RequestGet(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + framework::Scope* scope, const platform::DeviceContext* dev_ctx, SimpleBlockQueue* queue) - : RequestBase(service, cq), + : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope), - dev_ctx_(dev_ctx), queue_(queue) { - service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); + int method_id = static_cast(detail::GrpcMethod::kGetVariable); + service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, + cq_, this); } virtual ~RequestGet() {} @@ -101,24 +105,26 @@ class RequestGet final : public 
RequestBase { // proc request. std::string var_name = request_.varname(); auto* var = scope_->FindVar(var_name); + + ::grpc::ByteBuffer reply; if (var_name != FETCH_BARRIER_MESSAGE) { - SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); } - // TODO(gongwb): check var's info. - responder_.Finish(reply_, grpc::Status::OK, this); + + responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; - MessageWithName msg_with_name = - // request name reply - std::make_pair(var_name, std::move(reply_)); - queue_->Push(msg_with_name); + + if (var_name == FETCH_BARRIER_MESSAGE) { + sendrecv::VariableMessage msg; + MessageWithName msg_with_name = std::make_pair(var_name, msg); + queue_->Push(msg_with_name); + } } protected: sendrecv::VariableMessage request_; - sendrecv::VariableMessage reply_; - ServerAsyncResponseWriter responder_; + ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; - const platform::DeviceContext* dev_ctx_; SimpleBlockQueue* queue_; }; @@ -133,8 +139,8 @@ void AsyncGRPCServer::WaitClientGet(int count) { } void AsyncGRPCServer::RunSyncUpdate() { - grpc::ServerBuilder builder; - builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); + ::grpc::ServerBuilder builder; + builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials()); builder.SetMaxSendMessageSize(std::numeric_limits::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.RegisterService(&service_); @@ -182,8 +188,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { if (is_shut_down_) { return; } - RequestSend* send = - new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); + RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, + &var_recv_queue_, dev_ctx_); VLOG(4) << "Create RequestSend status:" << send->Status(); } @@ -198,7 +204,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { } // FIXME(typhoonzero): change cq_name to enum. -void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, +void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { TryToRegisterNewOne(); diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index b6666bcf96e484b0b17b935c0efb2930f19b19f2..9c21a07432031a6e4ac03fda357ff6bbff618418 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -14,28 +14,35 @@ limitations under the License. 
*/ #pragma once +#include +#include + #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/simple_block_queue.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" -#include -#include -#include -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/grpc_service.h" + +//#include namespace paddle { namespace operators { namespace detail { +typedef std::pair> + ReceivedMessage; +typedef SimpleBlockQueue ReceivedQueue; + typedef std::pair MessageWithName; class RequestBase; -class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { +class AsyncGRPCServer final { public: explicit AsyncGRPCServer(const std::string &address) : address_(address) {} @@ -50,14 +57,16 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } - const MessageWithName Get() { return this->var_recv_queue_.Pop(); } + const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } - void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } + void Push(const std::string &msg_name) { + this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); + } void ShutDown(); protected: - void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name, + void HandleRequest(::grpc::ServerCompletionQueue *cq, std::string cq_name, std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); @@ -66,18 +75,19 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { private: std::mutex cq_mutex_; volatile bool is_shut_down_ = false; - std::unique_ptr cq_send_; - std::unique_ptr cq_get_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; - sendrecv::SendRecvService::AsyncService service_; - std::unique_ptr server_; + GrpcService::AsyncService service_; + std::unique_ptr<::grpc::Server> server_; std::string address_; framework::Scope *scope_; const platform::DeviceContext *dev_ctx_; + // received variable from RPC, operators fetch variable from this queue. - SimpleBlockQueue var_recv_queue_; SimpleBlockQueue var_get_queue_; + ReceivedQueue var_recv_queue_; // condition of the sub program std::mutex barrier_mutex_; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h new file mode 100644 index 0000000000000000000000000000000000000000..ae6f9db3bd31a4b4839b34e8e53dd87f1ecf4b1d --- /dev/null +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -0,0 +1,118 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
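+//
+// This header replaces the protoc-generated SendRecvService stubs with a
+// hand-written GrpcService::AsyncService, so that an incoming SendVariable
+// request can be deserialized straight into a VariableResponse (see the
+// SerializationTraits specialization below) instead of going through an
+// intermediate protobuf message copy.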
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "paddle/fluid/operators/detail/variable_response.h"
+
+// NOTE: This method was originally created by tensorflow
+// (https://github.com/tensorflow/tensorflow/); we borrow this
+// method and made some modifications so that we can parse gRPC
+// requests without too much copying of the tensor data.
+
+namespace grpc {
+class CompletionQueue;
+class Channel;
+class RpcService;
+class ServerCompletionQueue;
+class ServerContext;
+
+// Support parsing/unparsing of tensorflow::VariableResponse.
+// Wire-format is identical to RecvVariableResponse.
+template <>
+class SerializationTraits<paddle::operators::detail::VariableResponse> {
+ public:
+  static Status Serialize(
+      const paddle::operators::detail::VariableResponse& msg,
+      grpc_byte_buffer** bp, bool* own_buffer) {
+    PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
+    return Status();
+  }
+  static Status Deserialize(grpc_byte_buffer* buffer,
+                            paddle::operators::detail::VariableResponse* msg,
+                            int max_message_size = INT_MAX) {
+    if (buffer == nullptr) {
+      return Status(StatusCode::INTERNAL, "No payload");
+    }
+
+    Status result = g_core_codegen_interface->ok();
+    if (result.ok()) {
+      paddle::operators::detail::GrpcByteSource source(buffer);
+      int ret = msg->Parse(&source);
+      if (ret != 0) {
+        result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
+      }
+    }
+    g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
+    return result;
+  }
+};
+}  // namespace grpc
+
+namespace paddle {
+namespace operators {
+namespace detail {
+
+enum class GrpcMethod {
+  kSendVariable,
+  kGetVariable,
+};
+
+static const int kGrpcNumMethods =
+    static_cast<int>(GrpcMethod::kGetVariable) + 1;
+
+inline const char* GrpcMethodName(GrpcMethod id) {
+  switch (id) {
+    case GrpcMethod::kSendVariable:
+      return "/sendrecv.SendRecvService/SendVariable";
+    case GrpcMethod::kGetVariable:
+      return "/sendrecv.SendRecvService/GetVariable";
+  }
+
+  // Shouldn't be reached.
+  PADDLE_ENFORCE(false, "Invalid id: not found valid method name");
+  return nullptr;
+}
+
+class GrpcService final {
+ public:
+  class AsyncService : public ::grpc::Service {
+   public:
+    AsyncService() {
+      for (int i = 0; i < kGrpcNumMethods; ++i) {
+        AddMethod(new ::grpc::internal::RpcServiceMethod(
+            GrpcMethodName(static_cast<GrpcMethod>(i)),
+            ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
+        ::grpc::Service::MarkMethodAsync(i);
+      }
+    }
+    virtual ~AsyncService() {}
+
+    // Make RequestAsyncUnary public for grpc_call.h
+    using ::grpc::Service::RequestAsyncUnary;
+  };
+};
+
+}  // namespace detail
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index b0215d4a80c9440f09c35434903fd6166b03e8b0..598aaa4c51a6c5cd32eeffe08bbae849aee1a1df 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -32,6 +32,9 @@ enum VarType {
   SELECTED_ROWS = 1;
 }
 
+// NOTICE(gongwb): don't modify this proto if you are not
+// familiar with how we serialize it in sendrecvop_utils.h
+// and deserialize it in variable_response.h.
message VariableMessage { enum Type { // Pod Types @@ -45,7 +48,6 @@ message VariableMessage { } message LodData { repeated int64 lod_data = 1; } - string varname = 1; // TODO(Yancey1989): reference framework::proto::VarDesc::VarType VarType type = 2; @@ -64,3 +66,5 @@ message VariableMessage { } message VoidMessage {} + +message TestMessage { int64 test_1 = 1; } diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 39117eeeb611b025c426938c60ddf82c6af232ca..d7bbf79c50651943d91c38bbaab775f5ee8dc395 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -13,61 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include +#include #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/detail/proto_encoder_helper.h" +#include "paddle/fluid/operators/detail/variable_response.h" namespace paddle { namespace operators { namespace detail { -void SerializeToMessage(const std::string& name, const framework::Variable* var, - const platform::DeviceContext& ctx, - sendrecv::VariableMessage* msg) { - msg->set_varname(name); - std::ostringstream oss; - switch (framework::ToVarType(var->Type())) { - case framework::proto::VarType_Type_LOD_TENSOR: - msg->set_type(sendrecv::VarType::LOD_TENSOR); - framework::SerializeToStream(oss, var->Get(), ctx); - break; - case framework::proto::VarType_Type_SELECTED_ROWS: - msg->set_type(sendrecv::VarType::SELECTED_ROWS); - framework::SerializeToStream(oss, var->Get(), - ctx); - break; - default: { - PADDLE_THROW("Serialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } - msg->set_serialized(oss.str()); -} - -void DeserializeFromMessage(const sendrecv::VariableMessage& msg, - const platform::DeviceContext& ctx, - framework::Variable* var) { - std::istringstream iss(msg.serialized()); - switch (msg.type()) { - case sendrecv::VarType::LOD_TENSOR: - DeserializeFromStream(iss, var->GetMutable(), ctx); - break; - case sendrecv::VarType::SELECTED_ROWS: { - DeserializeFromStream(iss, var->GetMutable(), - ctx); - break; - } - default: { - PADDLE_THROW("Deserialize does not support type: %s", - typeid(var->Type()).name()); - break; - } - } -} - void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg) { @@ -123,6 +81,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, static_cast(ctx); auto copy_size = tensor.memory_size(); payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, boost::get(tensor.place()), reinterpret_cast(tensor.data()), @@ -132,6 +91,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, platform::CPUPlace cpu; memory::Free(cpu, backing); }; + #endif } else { payload = tensor.data(); @@ -219,80 +179,11 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, - framework::Variable* var) { - sendrecv::VariableMessage meta; - GrpcByteBufferSource source; - source.Init(msg); - ::google::protobuf::io::CodedInputStream input(&source); - // 
do zerocopy parsing - PADDLE_ENFORCE(meta.ParseFromCodedStream(&input)); - PADDLE_ENFORCE(input.ConsumedEntireMessage()); - // dims is needed by both tensor and selectedrows - std::vector vecdims; - for (auto& d : meta.dims()) { - vecdims.push_back(d); - } - framework::DDim dims = framework::make_ddim(vecdims); - - if (meta.type() == sendrecv::LOD_TENSOR) { - auto* tensor = var->GetMutable(); - tensor->Resize(dims); - void* tensor_data = tensor->mutable_data( - ctx.GetPlace(), - paddle::operators::detail::ToTypeIndex(meta.data_type())); - framework::LoD lod; - for (int i = 0; i < meta.lod_level(); ++i) { - framework::Vector v; - for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) { - v.push_back(meta.lod(i).lod_data(j)); - } - lod.push_back(v); - } - tensor->set_lod(lod); - // How to avoid copying and use the message buffer directly? - // Maybe need to find a way to release all memory except tensor content. - if (platform::is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA - platform::CPUPlace cpu; - auto& gpu_dev_ctx = static_cast(ctx); - memory::Copy(boost::get(tensor->place()), - tensor_data, cpu, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size(), gpu_dev_ctx.stream()); - ctx.Wait(); -#endif - } else { - memcpy(tensor_data, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size()); - } - } else if (meta.type() == sendrecv::SELECTED_ROWS) { - auto* slr = var->GetMutable(); - auto* tensor = slr->mutable_value(); - int64_t* rows_data = slr->mutable_rows()->data(); - tensor->Resize(dims); - void* tensor_data = tensor->mutable_data( - ctx.GetPlace(), - paddle::operators::detail::ToTypeIndex(meta.data_type())); - if (platform::is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA - platform::CPUPlace cpu; - auto& gpu_dev_ctx = static_cast(ctx); - memory::Copy(boost::get(tensor->place()), - tensor_data, cpu, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size(), gpu_dev_ctx.stream()); - ctx.Wait(); -#endif - } else { - memcpy(tensor_data, - reinterpret_cast(meta.serialized().data()), - meta.serialized().size()); - } - // copy rows CPU data, GPU data will be copied lazly - memcpy(rows_data, reinterpret_cast(meta.rows().data()), - meta.rows().size()); - } + const framework::Scope* scope, + framework::Variable*& var) { + operators::detail::VariableResponse resp(scope, &ctx); + PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); + var = resp.GetVar(); } } // namespace detail diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 4fa6aefd3e0b1bd45ac52b1eff3b29126d79f03a..3b875627032a6b08cc70280b3cc825c2a703923f 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" @@ -36,21 +37,14 @@ namespace detail { typedef void (*DestroyCallback)(void*); -void SerializeToMessage(const std::string& name, const framework::Variable* var, - const platform::DeviceContext& ctx, - sendrecv::VariableMessage* msg); - -void DeserializeFromMessage(const sendrecv::VariableMessage& msg, - const platform::DeviceContext& ctx, - framework::Variable* var); - void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, ::grpc::ByteBuffer* msg); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, - framework::Variable* var); + const framework::Scope* scope, + framework::Variable*& var); inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { switch (type) { diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc index 2f06e5a686b996858d21930a1afa2861efca4a9b..4be5963794717e55bd03110996ad511e6fa0a1db 100644 --- a/paddle/fluid/operators/detail/test_serde.cc +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -16,11 +16,13 @@ limitations under the License. */ #include #include +#include #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" @@ -31,19 +33,21 @@ namespace operators = paddle::operators; namespace math = paddle::operators::math; namespace memory = paddle::memory; -void RunSerdeTestTensor(platform::Place place) { - // serialize var to ByteBuffer - framework::Variable var; - auto* tensor = var.GetMutable(); - tensor->Resize(framework::make_ddim({4, 8, 4, 2})); - framework::LoD lod; - lod.push_back(framework::Vector({1, 3, 8})); - tensor->set_lod(lod); - int tensor_numel = 4 * 8 * 4 * 2; +void RunSerdeTestSelectedRows(platform::Place place) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); + + // serialize var to ByteBuffer + framework::Variable var; + auto* slr = var.GetMutable(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + tensor->Resize(framework::make_ddim({2, 10})); tensor->mutable_data(place); - math::set_constant(ctx, tensor, 31.9); + int tensor_numel = 2 * 10; + math::set_constant(ctx, tensor, 32.7); + rows->push_back(3); + rows->push_back(10); ::grpc::ByteBuffer msg; operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); @@ -56,62 +60,67 @@ void RunSerdeTestTensor(platform::Place place) { for (const auto& s : slices) { tmp.append(reinterpret_cast(s.begin()), s.size()); } + sendrecv::VariableMessage varmsg; EXPECT_TRUE(varmsg.ParseFromString(tmp)); + EXPECT_EQ(varmsg.varname(), "myvar"); - EXPECT_EQ(varmsg.type(), 0); - EXPECT_EQ(varmsg.dims()[0], 4); - EXPECT_EQ(varmsg.dims()[1], 8); - EXPECT_EQ(varmsg.dims()[2], 4); - EXPECT_EQ(varmsg.dims()[3], 2); - EXPECT_EQ(varmsg.lod_level(), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); - 
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); - EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + EXPECT_EQ(varmsg.type(), 1); const float* tensor_data = reinterpret_cast(varmsg.serialized().data()); + const int64_t* rows_data = + reinterpret_cast(varmsg.rows().data()); for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data[i], 31.9); + EXPECT_FLOAT_EQ(tensor_data[i], 32.7); } - + EXPECT_EQ(rows_data[0], 3); + EXPECT_EQ(rows_data[1], 10); // deserialize zero-copy - framework::Variable var2; - operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); - auto tensor2 = var2.Get(); + // framework::Variable var2; + // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + framework::Scope scope; + scope.Var("myvar"); + operators::detail::TensorResponse resp(&scope, &ctx); + EXPECT_EQ(resp.Parse(msg), 0); + + framework::Variable* var2 = resp.GetVar(); + + auto* slr2 = var2->GetMutable(); + auto* tensor2 = slr2->mutable_value(); + auto* rows2 = slr2->mutable_rows(); float* tensor_data2 = nullptr; framework::Tensor tmp_tensor; if (platform::is_gpu_place(ctx.GetPlace())) { platform::CPUPlace cpu; - framework::TensorCopy(tensor2, cpu, &tmp_tensor); + framework::TensorCopy(*tensor2, cpu, &tmp_tensor); tensor_data2 = tmp_tensor.data(); } else { - tensor_data2 = const_cast(tensor2.data()); + tensor_data2 = const_cast(tensor2->data()); } + const int64_t* rows_data2 = rows2->data(); - EXPECT_EQ(varmsg.lod_level(), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); - EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); - EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); - for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); + for (int i = 0; i < tensor_numel; ++i) { + EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); + } + EXPECT_EQ(rows_data2[0], 3); + EXPECT_EQ(rows_data2[1], 10); } -void RunSerdeTestSelectedRows(platform::Place place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - +void RunTestLodTensor(platform::Place place, int from_type = 0) { // serialize var to ByteBuffer framework::Variable var; - auto* slr = var.GetMutable(); - auto* tensor = slr->mutable_value(); - auto* rows = slr->mutable_rows(); - tensor->Resize(framework::make_ddim({2, 10})); + auto* tensor = var.GetMutable(); + tensor->Resize(framework::make_ddim({4, 8, 4, 2})); + framework::LoD lod; + lod.push_back(framework::Vector({1, 3, 8})); + tensor->set_lod(lod); + int tensor_numel = 4 * 8 * 4 * 2; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); tensor->mutable_data(place); - int tensor_numel = 2 * 10; - math::set_constant(ctx, tensor, 32.7); - rows->push_back(3); - rows->push_back(10); + math::set_constant(ctx, tensor, 31.9); ::grpc::ByteBuffer msg; operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); @@ -126,43 +135,75 @@ void RunSerdeTestSelectedRows(platform::Place place) { } sendrecv::VariableMessage varmsg; EXPECT_TRUE(varmsg.ParseFromString(tmp)); - EXPECT_EQ(varmsg.varname(), "myvar"); - EXPECT_EQ(varmsg.type(), 1); + EXPECT_EQ(varmsg.type(), 0); + EXPECT_EQ(varmsg.dims()[0], 4); + EXPECT_EQ(varmsg.dims()[1], 8); + EXPECT_EQ(varmsg.dims()[2], 4); + EXPECT_EQ(varmsg.dims()[3], 2); + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); const float* tensor_data = reinterpret_cast(varmsg.serialized().data()); - const int64_t* rows_data = - 
reinterpret_cast(varmsg.rows().data()); for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data[i], 32.7); + EXPECT_FLOAT_EQ(tensor_data[i], 31.9); } - EXPECT_EQ(rows_data[0], 3); - EXPECT_EQ(rows_data[1], 10); + + // message binary + std::string str; + varmsg.SerializeToString(&str); + + // message bytebuffer + ::grpc::Slice slices_2[1]; + int num_slices = 1; + slices_2[0] = ::grpc::Slice(str.length()); + memcpy(const_cast(slices_2[0].begin()), str.c_str(), str.length()); + ::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices); + // deserialize zero-copy - framework::Variable var2; - operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + framework::Scope scope; + scope.Var("myvar"); + operators::detail::TensorResponse resp(&scope, &ctx); + if (from_type == 0) { + EXPECT_EQ(resp.Parse(msg), 0); + } else { + EXPECT_EQ(resp.Parse(bytebuffer2), 0); + } - auto* slr2 = var2.GetMutable(); - auto* tensor2 = slr2->mutable_value(); - auto* rows2 = slr2->mutable_rows(); + framework::Variable* var2 = resp.GetVar(); + + auto tensor2 = var2->Get(); float* tensor_data2 = nullptr; framework::Tensor tmp_tensor; if (platform::is_gpu_place(ctx.GetPlace())) { platform::CPUPlace cpu; - framework::TensorCopy(*tensor2, cpu, &tmp_tensor); + framework::TensorCopy(tensor2, cpu, &tmp_tensor); tensor_data2 = tmp_tensor.data(); } else { - tensor_data2 = const_cast(tensor2->data()); + tensor_data2 = const_cast(tensor2.data()); } - const int64_t* rows_data2 = rows2->data(); - for (int i = 0; i < tensor_numel; ++i) { - EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); - } - EXPECT_EQ(rows_data2[0], 3); - EXPECT_EQ(rows_data2[1], 10); + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + for (int i = 0; i < tensor_numel; ++i) EXPECT_FLOAT_EQ(tensor_data2[i], 31.9); +} + +TEST(LodTensor, GPU) { + platform::CUDAPlace place; + RunTestLodTensor(place); + RunTestLodTensor(place, 1); +} + +TEST(LodTensor, CPU) { + platform::CPUPlace place; + RunTestLodTensor(place); + RunTestLodTensor(place, 1); } TEST(SelectedRows, CPU) { @@ -174,13 +215,3 @@ TEST(SelectedRows, GPU) { platform::CUDAPlace place; RunSerdeTestSelectedRows(place); } - -TEST(Tensor, CPU) { - platform::CPUPlace place; - RunSerdeTestTensor(place); -} - -TEST(Tensor, GPU) { - platform::CUDAPlace place; - RunSerdeTestTensor(place); -} \ No newline at end of file diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc new file mode 100644 index 0000000000000000000000000000000000000000..12e8eb0b4da2252b104415aef4156bf100c3e565 --- /dev/null +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -0,0 +1,400 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
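The new variable_response.cc below parses the protobuf wire format by hand instead of going through a generated message: every field begins with a varint tag that packs the field number and wire type together, which is exactly what `GetTagFieldNumber`/`GetTagWireType` unpack. A self-contained Python sketch of those two primitives (illustrative only, not part of the patch):

```python
def read_varint(buf, pos):
    """Decode one base-128 varint, the primitive behind ReadVarint64."""
    result, shift = 0, 0
    while True:
        b = buf[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if not b & 0x80:  # high bit clear marks the last byte
            return result, pos
        shift += 7

def read_tag(buf, pos):
    """A tag is a varint holding (field_number << 3 | wire_type),
    mirroring GetTagFieldNumber and GetTagWireType."""
    tag, pos = read_varint(buf, pos)
    return tag >> 3, tag & 0x7, pos

# b'\x0a\x05myvar': field 1 (varname), wire type 2 (length-delimited),
# then length 5 and the bytes of "myvar"
field, wire_type, pos = read_tag(b"\x0a\x05myvar", 0)
assert (field, wire_type) == (1, 2)
```

Walking tags this way is what lets the parser hand tensor payload bytes straight to `memory::Copy` instead of materializing them in a proto string first.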
+ +#include "paddle/fluid/operators/detail/variable_response.h" +#include +#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +enum WireType { + WIRETYPE_VARINT = 0, + WIRETYPE_LENGTH_DELIMITED = 2, +}; + +inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; } + +inline WireType GetTagWireType(uint32_t tag) { + return static_cast(tag & 0x7); +} + +bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input, + int* result) { + uint64_t v; + if (input->ReadVarint64(&v) && v <= static_cast(INT_MAX)) { + *result = static_cast(v); + return true; + } else { + return false; + } +} + +bool ReadRaw(::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& dev_ctx, platform::Place place, + void* dest, int size) { + const void* data = NULL; + int size_to_write = 0; + + if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + auto& gpu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + + char* p = reinterpret_cast(dest); + while (size > 0) { + if (!input->GetDirectBufferPointer(&data, &size_to_write)) { + return false; + } + + memory::Copy(boost::get(place), + reinterpret_cast(p), cpu, data, size_to_write, + gpu_dev_ctx.stream()); + p += size_to_write; + size -= size_to_write; + + input->Skip(size_to_write); + } + gpu_dev_ctx.Wait(); +#else + PADDLE_THROW("Unexpected branch"); +#endif + return true; + } + + char* p = reinterpret_cast(dest); + while (size > 0) { + if (!input->GetDirectBufferPointer(&data, &size_to_write)) { + return false; + } + // TODO(gongwb): can we avoid copy? + platform::CPUPlace cpu; + memory::Copy(cpu, reinterpret_cast(p), cpu, data, size_to_write); + + p += size_to_write; + size -= size_to_write; + + input->Skip(size_to_write); + } + + return true; +} + +bool VariableResponse::CopyLodTensorData( + ::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, framework::DDim& dims, int length) { + auto var = scope_->FindVar(meta_.varname()); + auto* tensor = var->GetMutable(); + tensor->Resize(dims); + + framework::LoD lod; + for (int i = 0; i < meta_.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) { + v.push_back(meta_.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + + void* tensor_data = + tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); + + if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { + return false; + } + + return true; +} + +inline framework::DDim GetDims( + const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) { + std::vector vecdims; + for (auto& d : dims) { + vecdims.push_back(d); + } + return framework::make_ddim(vecdims); +} + +bool VariableResponse::CopySelectRowsTensorData( + ::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, framework::DDim& dims, int length) { + auto var = scope_->FindVar(meta_.varname()); + auto* slr = var->GetMutable(); + auto* tensor = slr->mutable_value(); + tensor->Resize(dims); + void* tensor_data = tensor->mutable_data( + ctx.GetPlace(), + paddle::operators::detail::ToTypeIndex(meta_.data_type())); + + if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { + return false; + } + + return true; +} + +bool VariableResponse::CopySelectRowsData( + ::google::protobuf::io::CodedInputStream* input, + const platform::DeviceContext& ctx, int length) 
{ + auto var = scope_->FindVar(meta_.varname()); + auto* slr = var->GetMutable(); + int64_t* rows_data = slr->mutable_rows()->data(); + + // copy rows CPU data, GPU data will be copied lazily. + platform::CPUPlace cpu; + if (!ReadRaw(input, ctx, cpu, rows_data, length)) { + return false; + } + + return true; +} + +bool ParseLodData(::google::protobuf::io::CodedInputStream* input, + std::vector* lod) { + while (true) { + auto p = input->ReadTagWithCutoff(127); + int tag = GetTagFieldNumber(p.first); + WireType wt = GetTagWireType(p.first); + + if (!p.second) { + return (tag == 0); + } + + switch (tag) { + case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: { + uint64_t v; + if (wt == WIRETYPE_VARINT) { + if (!input->ReadVarint64(&v)) { + return false; + } + lod->push_back(v); + break; + } + + if (wt == WIRETYPE_LENGTH_DELIMITED) { + int length = 0; + if (!input->ReadVarintSizeAsInt(&length)) { + return tag; + } + + for (int i = 0; i < length; i++) { + uint64_t v; + if (!input->ReadVarint64(&v)) { + return false; + } + lod->push_back(v); + } + break; + } + + return false; + } + default: { return false; } + } + } + + return true; +} + +int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) { + GrpcByteBufferSource source; + source.Init(byte_buffer); + GrpcByteBufferSourceWrapper r(&source); + + return Parse(&r); +} + +int VariableResponse::Parse(Source* source) { + ::google::protobuf::io::ZeroCopyInputStream* input_stream = + source->contents(); + ::google::protobuf::io::CodedInputStream input(input_stream); + input.SetTotalBytesLimit(INT_MAX, INT_MAX); + + while (true) { + auto p = input.ReadTagWithCutoff(127); + int tag = GetTagFieldNumber(p.first); + WireType wt = GetTagWireType(p.first); + if (!p.second) { + if (tag != 0) { + return -1; + } + + return 0; + } + + switch (tag) { + case sendrecv::VariableMessage::kVarnameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_varname(temp); + break; + } + case sendrecv::VariableMessage::kTypeFieldNumber: { + uint64_t v; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + + meta_.set_type(static_cast<::sendrecv::VarType>(v)); + break; + } + case sendrecv::VariableMessage::kDataTypeFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + + meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v)); + break; + } + case sendrecv::VariableMessage::kDimsFieldNumber: { + // not packed + if (wt == WIRETYPE_VARINT) { + uint64_t v; + if (!input.ReadVarint64(&v)) { + return tag; + } + meta_.add_dims(v); + break; + } + + // packed + if (wt == WIRETYPE_LENGTH_DELIMITED) { + int length = 0; + if (!input.ReadVarintSizeAsInt(&length)) { + return tag; + } + for (int i = 0; i < length; i++) { + uint64_t v; + if (!input.ReadVarint64(&v)) { + return tag; + } + meta_.add_dims(v); + } + break; + } + + return tag; + } + case sendrecv::VariableMessage::kLodLevelFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + meta_.set_lod_level(static_cast(v)); + break; + } + case sendrecv::VariableMessage::kLodFieldNumber: { + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p = + 
input.IncrementRecursionDepthAndPushLimit(length); + + std::vector lod_data; + if (p.second < 0 || !ParseLodData(&input, &lod_data)) { + return tag; + } + + if (!input.DecrementRecursionDepthAndPopLimit(p.first)) { + return false; + } + + if (lod_data.size() == 0) { + break; + } + + auto lod = meta_.add_lod(); + for (uint32_t i = 0; i < lod_data.size(); i++) { + lod->add_lod_data(lod_data[i]); + } + break; + } + case sendrecv::VariableMessage::kSerializedFieldNumber: { + PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || + meta_.type() == sendrecv::LOD_TENSOR) && + meta_.varname() != "", + "meta info should be got first!"); + + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + framework::DDim dims = GetDims(meta_.dims()); + if (meta_.type() == sendrecv::LOD_TENSOR) { + PADDLE_ENFORCE(meta_.lod_size() >= 0, + "lod info should be got first!"); + if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) { + return tag; + } + break; + } + + if (meta_.type() == sendrecv::SELECTED_ROWS) { + if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) { + return tag; + } + break; + } + + return tag; + } + case sendrecv::VariableMessage::kRowsFieldNumber: { + PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || + meta_.type() == sendrecv::LOD_TENSOR) && + meta_.varname() != "", + "meta info should be got first!"); + + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + + if (!CopySelectRowsData(&input, *dev_ctx_, length)) { + return tag; + } + break; + } + + default: { + // Unknown tag, return unknown error. + return -1; + } + } + } + + return 0; +} + +}; // namespace detail +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h new file mode 100644 index 0000000000000000000000000000000000000000..c7bc7a46e7bc8a24ae9b4931d347767c10277d22 --- /dev/null +++ b/paddle/fluid/operators/detail/variable_response.h @@ -0,0 +1,81 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
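`Parse` above sticks to one status convention, return 0 on success, -1 for an unknown field, otherwise the number of the field that failed, and its `PADDLE_ENFORCE` calls require the meta fields (varname, type) to arrive before the serialized payload. A simplified Python rendering of that dispatch (the varname/type numbers follow send_recv.proto; the payload field number is an assumption of this sketch):

```python
VARNAME, TYPE, SERIALIZED = 1, 2, 7  # SERIALIZED = 7 is assumed here

def parse(fields, state):
    """fields: iterable of (field_number, value) pairs, as a wire-format
    reader would yield them. Returns 0 / -1 / the failing field number."""
    for field_number, value in fields:
        if field_number == VARNAME:
            state["varname"] = value
        elif field_number == TYPE:
            state["type"] = value
        elif field_number == SERIALIZED:
            # meta info must have been seen first, like the PADDLE_ENFORCE
            if "varname" not in state or "type" not in state:
                return field_number
            state["data"] = value
        else:
            return -1  # unknown tag
    return 0

assert parse([(VARNAME, "myvar"), (TYPE, 0), (SERIALIZED, b"...")], {}) == 0
assert parse([(SERIALIZED, b"...")], {}) == SERIALIZED  # meta missing
```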
+
+#pragma once
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/framework/var_type.h"
+
+#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
+#include "paddle/fluid/operators/detail/send_recv.pb.h"
+
+#include "google/protobuf/io/coded_stream.h"
+#include "google/protobuf/io/zero_copy_stream.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
+
+namespace paddle {
+namespace operators {
+namespace detail {
+
+class VariableResponse {
+ public:
+  VariableResponse(const framework::Scope* scope,
+                   const platform::DeviceContext* dev_ctx)
+      : scope_(scope), dev_ctx_(dev_ctx){};
+
+  virtual ~VariableResponse(){};
+
+  // return:
+  //    0: ok.
+  //   -1: unknown error.
+  //   other: number of the field that failed to parse.
+  int Parse(Source* source);
+
+  // return:
+  //    0: ok.
+  //   -1: unknown error.
+  //   other: number of the field that failed to parse.
+  int Parse(const ::grpc::ByteBuffer& byte_buffer);
+
+  inline std::string Varname() { return meta_.varname(); }
+
+  // should call Parse first.
+  framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
+
+ private:
+  bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
+                                const platform::DeviceContext& ctx,
+                                framework::DDim& dims, int length);
+
+  bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
+                          const platform::DeviceContext& ctx, int length);
+
+  bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
+                         const platform::DeviceContext& ctx,
+                         framework::DDim& dims, int length);
+
+ private:
+  const framework::Scope* scope_;
+  const platform::DeviceContext* dev_ctx_;
+  // only the skeleton (meta info); tensor data is written straight into
+  // the scope variable during Parse.
+  sendrecv::VariableMessage meta_;
+};
+
+};  // namespace detail
+};  // namespace operators
+};  // namespace paddle

diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index a594de67e05acd28ffedc5407beecfaea1281444..31ea2a7e581950b5399a7a5efc9ae38b8ea3c52d 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -69,9 +69,7 @@ class ListenAndServOp : public framework::OperatorBase {
   }
 
   void Stop() override {
-    detail::MessageWithName term_msg;
-    term_msg.first = LISTEN_TERMINATE_MESSAGE;
-    rpc_service_->Push(term_msg);
+    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
     rpc_service_->ShutDown();
     server_thread_->join();
   }
@@ -108,7 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
       size_t recv_var_cnt = 0;
       int batch_barrier = 0;
       while (batch_barrier != fan_in) {
-        const detail::MessageWithName &v = rpc_service_->Get();
+        const detail::ReceivedMessage v = rpc_service_->Get();
         auto recv_var_name = v.first;
         if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
           LOG(INFO) << "received terminate message and exit";
@@ -121,12 +119,11 @@ class ListenAndServOp : public framework::OperatorBase {
         } else {
           VLOG(3) << "received grad: " << recv_var_name;
           recv_var_cnt++;
-          auto *var = recv_scope.FindVar(recv_var_name);
+          auto var = v.second->GetVar();
           if (var == nullptr) {
             LOG(ERROR) << "Can not find server side var: " << recv_var_name;
             PADDLE_THROW("Can not find server side var");
           }
-          detail::DeserializeFromMessage(v.second, dev_ctx, var);
           if (var->IsType<framework::SelectedRows>()) {
             sparse_vars.push_back(var);
           }

diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu
index 
34ea6a91ce7743462d378cf471a5ec3a12ca51d1..5518ebed3f792a5acdfbb27976bc2c6dbd78069a 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -89,6 +89,7 @@ void SoftmaxGradCUDNNFunctor::operator()( XGrad->mutable_data(context.GetPlace()))); } +template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 47cb336d87f8627d86ac33d6ac32c04d5d93f753..5596fa0648ccc151bc0d11de9c556599428a8d71 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -56,7 +56,9 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace, - ops::SoftmaxCUDNNKernel); -REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace, +namespace plat = paddle::platform; +REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, + ops::SoftmaxCUDNNKernel, + ops::SoftmaxCUDNNKernel); +REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/softmax_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf0244e8662e827a90d8472a097315680579ff6d --- /dev/null +++ b/paddle/fluid/operators/softmax_mkldnn_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +#include + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNMemDesc; + +using mkldnn::memory; // Note: paddle has also "memory" namespace +using mkldnn::primitive; +using mkldnn::softmax_forward; +using mkldnn::prop_kind; +using mkldnn::stream; + +template +class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto& dev_ctx = ctx.template device_context(); + auto mkldnn_engine = dev_ctx.GetEngine(); + const Tensor* input = ctx.Input("X"); + Tensor* output = ctx.Output("Out"); + PADDLE_ENFORCE(input->dims().size() == 2UL, + "The input of softmax op must be a 2D matrix."); + const T* input_data = input->data(); + // allocate memory for output + T* output_data = output->mutable_data(ctx.GetPlace()); + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + // MKL-DNN does support softmax over selected axis. 
Having 2D Tensor, + // we will make normalization after final eg. axis: 1 + PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])), + "Softmax input and output dimensions should match"); + // Same memory descriptor to be used for input and output + memory::dims softmax_tz = {src_tz[0], src_tz[1]}; + // Currently only supports NC data format + // TODO(jczaja-intel): support more formats + auto softmax_md = + MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc); + // Normalization is made after innermost dimension eg. C out of NC + auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring, + softmax_md, 1 /*dim: C*/); + // create memory primitives + auto softmax_src_memory = + memory({softmax_md, mkldnn_engine}, (void*)input_data); + auto softmax_dst_memory = + memory({softmax_md, mkldnn_engine}, (void*)output_data); + auto softmax_prim_desc = + softmax_forward::primitive_desc(softmax_desc, mkldnn_engine); + auto softmax = softmax_forward(softmax_prim_desc, softmax_src_memory, + softmax_dst_memory); + std::vector pipeline{softmax}; + stream(stream::kind::eager).submit(pipeline).wait(); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace, + ops::SoftmaxMKLDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 1b63f8a499e5d20d2f10c3cd1024d1bcf78764d4..e2c0f915d96b7746191572fa27b725d90cb6e2e5 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -13,7 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/softmax_op.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -38,26 +44,32 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - bool use_cudnn = ctx.Attr("use_cudnn"); - bool runtime_cudnn_support = false; + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(ctx.GetPlace())) { - auto& dev_ctx = - ctx.template device_context(); - runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? 
true : false; + if (platform::CanCUDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kCUDNN; } #endif - framework::LibraryType library_ = framework::LibraryType::kPlain; - if (use_cudnn && runtime_cudnn_support) { - library_ = framework::LibraryType::kCUDNN; +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; } +#endif + + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, + "float16 can only be used when CUDNN is used"); + } + std::string data_format = ctx.Attr("data_format"); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - framework::StringToDataLayout(data_format), library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::StringToDataLayout(data_format), + library_); } }; - class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { public: SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -77,6 +89,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") .SetDefault("AnyLayout"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(R"DOC( Softmax Operator. @@ -119,19 +134,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - bool use_cudnn = ctx.Attr("use_cudnn"); - bool runtime_cudnn_support = false; + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(ctx.GetPlace())) { - auto& dev_ctx = - ctx.template device_context(); - runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false; - } -#endif - framework::LibraryType library_ = framework::LibraryType::kPlain; - if (use_cudnn && runtime_cudnn_support) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } +#endif std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 98b4178177b0a8bafd6fe34a92be2a07a2fbc5a7..59b76a1edb5ec5900520fbccb6a6f8f6e7a70aa4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -10,43 +10,45 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/platform/device_context.h" +#include #include "paddle/fluid/memory/memory.h" - namespace paddle { namespace platform { DeviceContextPool* DeviceContextPool::pool = nullptr; -const platform::DeviceContext* DeviceContextPool::Get( - const platform::Place& place) { +platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { PADDLE_THROW( "'Place' is not supported, Please re-compile with WITH_GPU " "option"); } - return it->second; + return it->second.get(); } DeviceContextPool::DeviceContextPool( const std::vector& places) { PADDLE_ENFORCE_GT(places.size(), 0); - for (size_t i = 0; i < places.size(); i++) { - if (platform::is_cpu_place(places[i])) { + using PtrType = std::unique_ptr; + std::unordered_set set; + for (auto& p : places) { + set.insert(p); + } + + for (auto& p : set) { + if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN - device_contexts_.emplace(places[i], - new platform::MKLDNNDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new MKLDNNDeviceContext(boost::get(p)))); #else - device_contexts_.emplace(places[i], - new platform::CPUDeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CPUDeviceContext(boost::get(p)))); #endif - } else if (platform::is_gpu_place(places[i])) { + } else if (platform::is_gpu_place(p)) { #ifdef PADDLE_WITH_CUDA - device_contexts_.emplace(places[i], - new platform::CUDADeviceContext( - boost::get(places[i]))); + device_contexts_.emplace( + p, PtrType(new CUDADeviceContext(boost::get(p)))); #else PADDLE_THROW( "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " @@ -159,6 +161,7 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { + std::lock_guard guard(mutex_); PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaGetLastError()); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 603b890af13b529c490c29112a73a09cc815d07a..202394c7be7e103a609dd0999fc883c794ef0edd 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; + mutable std::mutex mutex_; cudaStream_t stream_; cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; @@ -159,7 +160,7 @@ class DeviceContextPool { } /*! \brief Return handle of single device context. 
*/ - const platform::DeviceContext* Get(const platform::Place& place); + platform::DeviceContext* Get(const platform::Place& place); template const typename DefaultDeviceContextType::TYPE* GetByPlace( @@ -172,19 +173,8 @@ class DeviceContextPool { private: static DeviceContextPool* pool; - constexpr static int LEFT_SHIFT = 8; - struct Hash { - std::hash hash_; - size_t operator()(const platform::Place& place) const { - int pre_hash = place.which() << LEFT_SHIFT; - if (platform::is_gpu_place(place)) { - pre_hash += boost::get(place).GetDeviceId(); - } - return hash_(pre_hash); - } - }; - std::unordered_map + std::unordered_map, PlaceHash> device_contexts_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index 501bddfc6ec8b5d0bf554b0911c32e47fd51ec15..4cc8b377b8b671eb5a446ecbae21ba9628fbd2c8 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -65,6 +65,18 @@ bool is_cpu_place(const Place &); bool places_are_same_class(const Place &, const Place &); bool is_same_place(const Place &, const Place &); +struct PlaceHash { + std::size_t operator()(const Place &p) const { + constexpr size_t num_dev_bits = 4; + std::hash ihash; + size_t dev_id = 0; + if (is_gpu_place(p)) { + dev_id = boost::get(p).device; + } + return ihash(dev_id << num_dev_bits | p.which()); + } +}; + std::ostream &operator<<(std::ostream &, const Place &); template diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 6be2bd8fad9e33cf4e1dcafdd6b8f39111bdbe88..2e9b088bfa596d86c3c43c09d360da772fb2775a 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -35,7 +35,7 @@ function cmake_gen() { -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} ${PYTHON_FLAGS} -DWITH_DSO=ON - -DWITH_DOC=OFF + -DWITH_DOC=${WITH_DOC:-OFF} -DWITH_GPU=${WITH_GPU:-OFF} -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} -DWITH_MKL=${WITH_MKL:-ON} @@ -60,7 +60,7 @@ EOF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} \ ${PYTHON_FLAGS} \ -DWITH_DSO=ON \ - -DWITH_DOC=OFF \ + -DWITH_DOC=${WITH_DOC:-OFF} \ -DWITH_GPU=${WITH_GPU:-OFF} \ -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} \ -DWITH_MKL=${WITH_MKL:-ON} \ @@ -231,7 +231,7 @@ gen_capi_package gen_fluid_inference_lib if [[ ${WITH_C_API:-OFF} == "ON" ]]; then - printf "PaddlePaddle C-API libraries was generated on build/paddle.tgz\n" + printf "PaddlePaddle C-API libraries was generated on build/paddle.tgz\n" else printf "If you need to install PaddlePaddle in develop docker image," printf "please make install or pip install build/python/dist/*.whl.\n" diff --git a/paddle/scripts/tools/build_docs/.gitignore b/paddle/scripts/tools/build_docs/.gitignore deleted file mode 100644 index 6ec14c8f5bc3774a81dbe87c44f458594b38f12c..0000000000000000000000000000000000000000 --- a/paddle/scripts/tools/build_docs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -doc -doc_cn diff --git a/paddle/scripts/tools/build_docs/build_docs.sh b/paddle/scripts/tools/build_docs/build_docs.sh deleted file mode 100755 index f9bc8bf63ae9afdfca1ff660bc83e62e71f03005..0000000000000000000000000000000000000000 --- a/paddle/scripts/tools/build_docs/build_docs.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -docker run --rm \ - -v $(git rev-parse --show-toplevel):/paddle \ - -e "WITH_GPU=OFF" \ - -e "WITH_AVX=ON" \ - -e "WITH_DOC=ON" \ - -e "WOBOQ=ON" \ - ${1:-"paddlepaddle/paddle:latest-dev"} diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index 
0fc4981a8e9da09f15e6d0a5e5c6761e01328876..3e4292d23550b853ea73de787a1c053e1f2c80fd 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0): return channel -def channel_send(channel, value, copy=False): +def channel_send(channel, value, is_copy=False): """ Sends a value through a channel variable. Used by an unbuffered or buffered channel to pass data from within or to a concurrent Go block, where @@ -141,8 +141,8 @@ def channel_send(channel, value, copy=False): channel (Variable|Channel): Channel variable created using `make_channel`. value (Variable): Value to send to channel - copy (bool): Copy data while channel send. If False, then data - is moved. The input cannot be used after move. + is_copy (bool): Copy data while channel send. If False, then data + is moved. The input cannot be used after move. (default False) Returns: Variable: The boolean status on whether or not the channel successfully sent the passed value. @@ -166,7 +166,7 @@ def channel_send(channel, value, copy=False): X = value - if copy is True: + if is_copy is True: copied_X = helper.create_variable( name=unique_name.generate(value.name + '_copy'), type=value.type, diff --git a/python/paddle/fluid/debuger.py b/python/paddle/fluid/debuger.py index 97fa182c4007cc730c06e9f95259a2509e01ecdf..7b4afa9bf65e1369329cd4648c1f5c4bd8fa8357 100644 --- a/python/paddle/fluid/debuger.py +++ b/python/paddle/fluid/debuger.py @@ -16,7 +16,6 @@ import sys import re from graphviz import GraphPreviewGenerator import proto.framework_pb2 as framework_pb2 -import paddle.fluid.core as core _vartype2str_ = [ "UNK", @@ -126,7 +125,6 @@ def pprint_block_codes(block_desc, show_backward=False): def is_var_backward(var_desc): return "@GRAD" in var_desc.name - #print(type(block_desc)) if type(block_desc) is not framework_pb2.BlockDesc: block_desc = framework_pb2.BlockDesc.FromString( block_desc.serialize_to_string()) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index ad655ee96cee0744e7bedb17163faf7d8d1d8877..33cea96421bf93f1693bc06e7412b561f1bd2a32 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -20,6 +20,7 @@ from layer_helper import LayerHelper from distributed_spliter import * import math from . import core +import debuger class VarBlock: @@ -289,6 +290,7 @@ class DistributeTranspiler: dtype=v.dtype, shape=v.shape) recv_inputs.append(var) + # step3 optimize_block = pserver_program.create_block(0) # step 4 diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 70ecffd910a46570b5a8e576d88039fa5e22e726..3e78788f470556d2196b5104f69a0a3285543ec4 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -918,6 +918,24 @@ class Block(object): name=v.name) self.vars[new_p.name] = new_p + def clone_variable(self, var): + """ + Clone a variable into current block. + Args: + var: the variable to be cloned. + + Returns: + The new variable cloned from 'var' in current block. + """ + assert isinstance(var, Variable) + return self.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + class Program(object): def __init__(self): @@ -960,14 +978,14 @@ class Program(object): """Clone the Program object Set for_test to False when we want to clone the program for training. 
- Set for_test to True when we want to clone the program for testing. + Set for_test to True when we want to clone the program for testing. Args: for_test(bool): Some operators, such as batch_norm and drop_out ops, behave differently in training and testing. If for_test is True, the is_test attributes in these operators will be set to True for - testing purposes, otherwise, they remain unchanged. - + testing purposes, otherwise, they remain unchanged. + Returns(Program): The cloned Program object. """ diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index da7e74c901e1f5be709c5f9d73f048bfda0c5549..58b668227168c5c5e080f3928035ad98303bbae9 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -399,6 +399,9 @@ class LayerHelper(object): if isinstance(act, basestring): act = {'type': act} tmp = self.create_tmp_variable(dtype=input_var.dtype) + + if 'use_mkldnn' in self.kwargs: + act['use_mkldnn'] = self.kwargs.get('use_mkldnn') act_type = act.pop('type') self.append_op( type=act_type, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2ce68f95057f7820d7ab59ba2b41171c7ecd3654..679de6ce2aa67abe1322702fcb371eded0130698 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -82,6 +82,7 @@ def fc(input, num_flatten_dims=1, param_attr=None, bias_attr=None, + use_mkldnn=False, act=None, name=None): """ @@ -163,8 +164,11 @@ def fc(input, inputs={"X": input_var, "Y": w}, outputs={"Out": tmp}, - attrs={"x_num_col_dims": num_flatten_dims, - "y_num_col_dims": 1}) + attrs={ + "x_num_col_dims": num_flatten_dims, + "y_num_col_dims": 1, + 'use_mkldnn': use_mkldnn + }) mul_results.append(tmp) # sum diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index d7bad221c5fa7b18137bf317125195267437a644..f5c6b47d243dcf4ba985cfb41fc23b44d3ed809f 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -69,6 +69,7 @@ __all__ = [ 'gaussian_random_batch_size_like', 'cumsum', 'scatter', + 'sum', ] + __activations__ for _OP in set(__all__): diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index a33760a528f667b7afabafa19762eca7d1ef0635..180575c35dc6e115e11cccf9fff9fb2d3cd7e9a6 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict - +from paddle.fluid.framework import Program import framework import layers from backward import append_backward @@ -23,9 +23,11 @@ from initializer import Constant from layer_helper import LayerHelper from regularizer import append_regularization_ops from clip import append_gradient_clip_ops, error_clip_callback +from contextlib import contextmanager __all__ = [ - 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Adadelta' + 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', + 'Adadelta', 'ModelAverage' ] @@ -121,7 +123,12 @@ class Optimizer(object): """ pass - def _add_accumulator(self, name, param, dtype=None, fill_value=0.0): + def _add_accumulator(self, + name, + param, + dtype=None, + fill_value=0.0, + shape=None): """Utility function to add an accumulator for a parameter Args: @@ -135,17 +142,19 @@ class Optimizer(object): param.name in self._accumulators[name]): raise Exception("Accumulator {} already exists for parameter {}". 
                            format(name, param.name))
-
+        if shape == None:
+            shape = param.shape
         assert isinstance(self.helper, LayerHelper)
         var = self.helper.create_global_variable(
             name=unique_name.generate(name),
             persistable=True,
             dtype=dtype or param.dtype,
             type=param.type,
-            shape=param.shape)
+            shape=shape)
         self.helper.set_variable_initializer(
             var, initializer=Constant(value=float(fill_value)))
         self._accumulators[name][param.name] = var
+        return var
 
     def _get_accumulator(self, name, param):
         """Utility function to fetch an accumulator for a parameter
@@ -797,3 +806,143 @@ Adamax = AdamaxOptimizer
 DecayedAdagrad = DecayedAdagradOptimizer
 Adadelta = AdadeltaOptimizer
 RMSProp = RMSPropOptimizer
+
+
+class ModelAverage(Optimizer):
+    """Accumulate the average of parameters within a sliding window. The
+    average result will be saved in temporary variables which can be applied
+    to parameter variables of the current model by calling the 'apply()'
+    method, and the 'restore()' method is used to restore the parameter
+    values of the current model.
+
+    The size of the average window is determined by average_window_rate,
+    min_average_window, max_average_window and the current update times.
+
+    Args:
+        params_grads: A list of parameter-grad variable pairs.
+        average_window_rate: The rate of the average window.
+        min_average_window: The minimum size of the average window.
+        max_average_window: The maximum size of the average window.
+
+    Examples:
+      ...
+      optimizer = fluid.optimizer.Momentum()
+      _, params_grads = optimizer.minimize(cost)
+      model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
+                                                   min_average_window=10000,
+                                                   max_average_window=20000)
+      for pass_id in range(args.pass_num):
+          for data in train_reader():
+              exe.run(fluid.default_main_program()...)
+
+          with model_average.apply(exe):
+              for data in test_reader():
+                  exe.run(inference_program...)
+ """ + + def __init__(self, + params_grads, + average_window_rate, + min_average_window=10000, + max_average_window=10000, + **kwargs): + super(ModelAverage, self).__init__(0.0, **kwargs) + self.average_window = average_window_rate + self.min_average_window = min_average_window + self.max_average_window = max_average_window + self.params_grads = params_grads + for param, grad in self.params_grads: + if grad is not None: + self._append_average_accumulate_op(param) + + self.apply_program = Program() + block = self.apply_program.global_block() + with program_guard(main_program=self.apply_program): + for param_grad in self.params_grads: + if param_grad[1] is not None: + self._add_average_apply_op(block, param_grad) + + self.restore_program = Program() + block = self.restore_program.global_block() + with program_guard(main_program=self.restore_program): + for param_grad in self.params_grads: + if param_grad[1] is not None: + self._add_average_restore_op(block, param_grad) + + def _add_average_apply_op(self, block, param_grad): + param = block.clone_variable(param_grad[0]) + grad = block.clone_variable(param_grad[1]) + sum_1 = block.clone_variable(self._get_accumulator('sum_1', param)) + sum_2 = block.clone_variable(self._get_accumulator('sum_2', param)) + sum_3 = block.clone_variable(self._get_accumulator('sum_3', param)) + num_accumulates = block.clone_variable( + self._get_accumulator('num_accumulates', param)) + old_num_accumulates = block.clone_variable( + self._get_accumulator('old_num_accumulates', param)) + num_updates = block.clone_variable( + self._get_accumulator('num_updates', param)) + # backup param value to grad + layers.assign(input=param, output=grad) + # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates) + tmp = layers.sum(x=[num_accumulates, old_num_accumulates]) + sum = layers.sum(x=[sum_1, sum_2, sum_3]) + tmp = layers.cast(x=tmp, dtype='float32') + sum = layers.cast(x=sum, dtype='float32') + layers.elementwise_div(x=sum, y=tmp, out=param) + + def _add_average_restore_op(self, block, param_grad): + param = block.clone_variable(param_grad[0]) + grad = block.clone_variable(param_grad[1]) + layers.assign(input=grad, output=param) + + def _append_average_accumulate_op(self, param): + self.helper = LayerHelper("average_accumulate") + sum_1 = self._add_accumulator('sum_1', param) + sum_2 = self._add_accumulator('sum_2', param) + sum_3 = self._add_accumulator('sum_3', param) + num_accumulates = self._add_accumulator( + 'num_accumulates', param, dtype='int64', shape=[1]) + old_num_accumulates = self._add_accumulator( + 'old_num_accumulates', param, dtype='int64', shape=[1]) + num_updates = self._add_accumulator( + 'num_updates', param, dtype='int64', shape=[1]) + + self.helper.append_op( + type='average_accumulates', + inputs={ + "param": param, + "in_sum_1": sum_1, + "in_sum_2": sum_2, + "in_sum_3": sum_3, + "in_num_accumulates": num_accumulates, + "in_old_num_accumulates": old_num_accumulates, + "in_num_updates": num_updates + }, + outputs={ + "out_sum_1": sum_1, + "out_sum_2": sum_2, + "out_sum_3": sum_3, + "out_num_accumulates": num_accumulates, + "out_old_num_accumulates": old_num_accumulates, + "out_num_updates": num_updates, + }, + attrs={ + "average_window": self.average_window, + "min_average_window": self.min_average_window, + "max_average_window": self.max_average_window, + }) + + @contextmanager + def apply(self, executor, need_restore=True): + """Apply average values to parameters of current model. 
+ """ + executor.run(self.apply_program) + try: + yield + finally: + if need_restore: + self.restore(executor) + + def restore(self, executor): + """Restore parameter values of current model. + """ + executor.run(self.restore_program) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index eab41ebe711bd21bdc3b34ca83ab57388cc35ba2..1e3decfbaf0691e912b96b415b68353e626cf51e 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest from scipy.special import expit @@ -212,18 +213,39 @@ class TestRound(OpTest): class TestRelu(OpTest): def setUp(self): self.op_type = "relu" - x = np.random.uniform(-1, 1, [11, 17]).astype("float32") + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) # The same reason with TestAbs x[np.abs(x) < 0.005] = 0.02 - self.inputs = {'X': x} - self.outputs = {'Out': np.maximum(self.inputs['X'], 0)} + out = np.maximum(x, 0) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return self.check_grad(['X'], 'Out', max_relative_error=0.007) + def init_dtype(self): + pass + + +class TestFP16Relu(TestRelu): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + class TestBRelu(OpTest): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 4f20da2b926823db9e7ec92c95178b6d3d1feec9..33d60c7e31ce0817ad26ea1c1c974339936052d3 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -29,15 +29,20 @@ class TestSoftmaxOp(OpTest): def setUp(self): self.op_type = "softmax" self.use_cudnn = False - self.inputs = { - 'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32") - } - self.outputs = { - 'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X']) + self.use_mkldnn = False + self.dtype = np.float32 + self.init_kernel_type() + + x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype) + out = np.apply_along_axis(stable_softmax, 1, x) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + self.attrs = { + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn } - self.attrs = {'use_cudnn': self.use_cudnn, } - def init_op_type(self): + def init_kernel_type(self): pass def test_check_output(self): @@ -48,6 +53,8 @@ class TestSoftmaxOp(OpTest): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return if self.use_cudnn: place = core.CUDAPlace(0) self.check_grad_with_place( @@ -57,8 +64,25 @@ class TestSoftmaxOp(OpTest): class TestSoftmaxCUDNNOp(TestSoftmaxOp): - def init_op_type(self): + def init_kernel_type(self): + self.use_cudnn = True + + +class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp): + def init_kernel_type(self): self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + 
+                self.check_output_with_place(place, atol=1e-3)
+
+
+class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index eac2cb316835fda0a52ac9895eaa80914d0f1e5b..3684d1e8f73a21d9c6f2a5985f8b40ed6984057b 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2747,17 +2747,17 @@ def img_pool_layer(input,
 
     .. math::
 
-       w & = 1 + \\frac{ceil(input\_width + 2 * padding - pool\_size)}{stride}
+       w & = 1 + ceil(\\frac{input\_width + 2 * padding - pool\_size}{stride})
 
-       h & = 1 + \\frac{ceil(input\_height + 2 * padding\_y - pool\_size\_y)}{stride\_y}
+       h & = 1 + ceil(\\frac{input\_height + 2 * padding\_y - pool\_size\_y}{stride\_y})
 
     - ceil_mode=False:
 
     .. math::
 
-       w & = 1 + \\frac{floor(input\_width + 2 * padding - pool\_size)}{stride}
+       w & = 1 + floor(\\frac{input\_width + 2 * padding - pool\_size}{stride})
 
-       h & = 1 + \\frac{floor(input\_height + 2 * padding\_y - pool\_size\_y)}{stride\_y}
+       h & = 1 + floor(\\frac{input\_height + 2 * padding\_y - pool\_size\_y}{stride\_y})
 
     The example usage is: