diff --git a/benchmark/IntelOptimizedPaddle.md b/benchmark/IntelOptimizedPaddle.md index 6cc9598947acbdacfbf4c4379987bab8ed7611b0..084d3237d9cfe9ca4837f77cf5f70a2449cfcc03 100644 --- a/benchmark/IntelOptimizedPaddle.md +++ b/benchmark/IntelOptimizedPaddle.md @@ -93,6 +93,15 @@ Test on batch size 1, 2, 4, 8, 16 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz | MKLML | 22.74 | 41.56 | 81.22 | 133.47 | 210.53 | | MKL-DNN | 175.10 | 272.92 | 450.70 | 512.00 | 600.94 | +- Alexnet + +| BatchSize | 1 | 2 | 4 | 8 | 16 | +|-----------|--------|--------|--------|--------|--------| +| OpenBLAS | | | | | | +| MKLML | 21.32 | 36.55 | 73.06 | 131.15 | 192.77 | +| MKL-DNN | 442.91 | 656.41 | 719.10 | 847.68 | 850.51 | + +chart TBD ### Laptop TBD diff --git a/benchmark/paddle/image/alexnet.py b/benchmark/paddle/image/alexnet.py index 77d130ae34059d1e87040d00346ac1dadd86b0d8..cad6051f1413a5bb95f87a940f3aa81e49e5d282 100644 --- a/benchmark/paddle/image/alexnet.py +++ b/benchmark/paddle/image/alexnet.py @@ -19,7 +19,11 @@ args = { 'num_samples': num_samples } define_py_data_sources2( - "train.list", None, module="provider", obj="process", args=args) + "train.list" if not is_infer else None, + "test.list" if is_infer else None, + module="provider", + obj="process", + args=args) settings( batch_size=batch_size, diff --git a/benchmark/paddle/image/run_openblas_infer.sh b/benchmark/paddle/image/run_openblas_infer.sh index da034f3b9dff794e22086a5295ad2b0c2361c356..71a49231a5527ebee9f45d5f4650ce2a4f6a1c31 100755 --- a/benchmark/paddle/image/run_openblas_infer.sh +++ b/benchmark/paddle/image/run_openblas_infer.sh @@ -8,15 +8,19 @@ function clock_to_seconds() { } function infer() { - unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY topology=$1 layer_num=$2 bs=$3 - thread=`nproc` - if [ $thread -gt $bs ]; then - thread=$bs + trainers=`nproc` + if [ $trainers -gt $bs ]; then + trainers=$bs fi - log="logs/infer-${topology}-${layer_num}-${thread}openblas-${bs}.log" + log="logs/infer-${topology}-${layer_num}-${trainers}openblas-${bs}.log" + threads=$((`nproc` / trainers)) + if [ $threads -eq 0 ]; then + threads=1 + fi + export OPENBLAS_NUM_THREADS=$threads models_in="models/${topology}-${layer_num}/pass-00000/" if [ ! -d $models_in ]; then @@ -28,7 +32,7 @@ function infer() { --config="${topology}.py" \ --use_mkldnn=False \ --use_gpu=False \ - --trainer_count=$thread \ + --trainer_count=$trainers \ --log_period=$log_period \ --config_args="batch_size=${bs},layer_num=${layer_num},is_infer=True,num_samples=256" \ --init_model_path=$models_in \ diff --git a/benchmark/paddle/image/run_openblas_train.sh b/benchmark/paddle/image/run_openblas_train.sh index e9df83fee2a3f796b7234b39619364f6ee4d5dc9..935cff6f2c97d25d6de556cfee25e27dbe49b5b6 100755 --- a/benchmark/paddle/image/run_openblas_train.sh +++ b/benchmark/paddle/image/run_openblas_train.sh @@ -1,7 +1,7 @@ set -e function train() { - unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY + export OPENBLAS_NUM_THREADS=1 topology=$1 layer_num=$2 bs=$3 diff --git a/doc/design/backward.md b/doc/design/backward.md new file mode 100644 index 0000000000000000000000000000000000000000..35f03692bb052e0a04db18d28f6f8d901215a553 --- /dev/null +++ b/doc/design/backward.md @@ -0,0 +1,156 @@ +# Backward Building + +## Motivation + +In Neural Network, most models are solved by the backpropagation algorithm(known as **BP**) at present. Technically, BP calculates the gradient of the loss function, then propagates it back through the networks following the chain rule. However, when configuring the model structure, users do not need to define the backward part. So a mechanism is required by the framework which can complete the model's backward part automatically according to the given forward part. + +When implementing a specific `op`, the developer is also asked to implement its backward version, called `grad_op`. A `grad_op` takes gradients of its corresponding `op`'s outputs, and calculate gradients of the `op`'s inputs. During the building of a model's backward part, the framework creates each forward `op`'s `grad_op`, and then string them together in reverse order of forwarding part. In this way, gradients spread from the end to the beginning of the model, in another word, from the loss to parameters. + +## Challenges + +The motivation of backward building is apparent. However, implementation it correctly is not so easy. In the **Fluid** design, a deep learning model is described by `Program`, `Block`, `Op` and `Variable`. The `Block` itself can be nested. It means that the `op`s and `variable`s are scattered across different blocks rather than all be gathered in a single graph. Our backward building algorithm shall visit blocks in recursive order and be able to insert `grad_op`s and new created `variable`s into the right place. + +## Usage + +Although the whole algorithm is comprised of many functions, only one is exposed as API: + +```python +def append_backward(loss, parameter_list=None, no_grad_set=None): + """ + Append backward part to main_program + + Args: + loss(Variable): The variable generated by the cost function. + parameter_list(list): Parameters that need to be updated by optimizers. + If None, it means all parameters need to be updated. + + no_grad_set(set): Variables that have no gradients in Block 0. + If None, the set will be generated inside the function and + contains all variables with `step_gradient=True` from all blocks. + + Return: + (list[Variable]): list of (parameters, gradients) pair. + """ +``` + +By invoking this API, the framework appends backward part of the program where the `loss` is. It takes three arguments. `loss` means the final loss value. It must be a scalar and is usually the output of the loss layer. It is also where the gradient generated and backpropagation starts. `parameter_list` marks all parameters needs updating. If it's `None`, all parameter will be updated by optimizers. `no_grad_set` marks variables without gradient. if all outputs of some `grad_op` are in `no_grad_set`, the `grad_op` will not be run. + +This API will be invoked automatically before optimizer building. +As a result, in most cases, users do not need to invoke the API by themselves to append backward part. + +## Implementation + +The implementation of backward building algorithm is in `backward.py` file. The whole algorithm can be divided into two independent parts: creating `grad_op`s and creating new variables. + +### Creating `grad_op`s + +The creating of `grad_op`s is implemented by: + +```python +def _append_backward_ops_(target, + block, + target_block, + no_grad_dict, + grad_to_var): + """ + Create all grad ops, and insert them into given block + + Args: + target(Variable): the target variable of forward pass + block(Block): the block where forward ops are + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(set) a set of varibale names. These varibales have no gradient + grad_to_var(dict)(output argument): + key(str): grad variable name + val(str): corresponding forward variable name + """ +``` + +Given a `block`, the function will traverses all `op`s in this block in reverse order, gets corresponding `grad_op` from the C++ core via `core.get_grad_op_desc()`, then append it to `target_block`. + +However, some specific `op`(e.g. `while_op`, `if_else_op`) can hold its own sub-block. For these sub-blocks contains `op`s as well, the `grad_op` creating should be recursive. + +During the reverse traversal, we check each `op` whether it has an attribute named `sub_block`. If so, it means there is a sub-block and we need to deal with it first. After creating a new block whose father is the one in `op`'s attribute, we invoke `_append_backward_ops_()` recursively, assigning the new block to parameter `target_block` and the one in `op`'s attribute to `block`. The *pseudo-code* shows this process: + +``` +******* pseudo-code ******** +for op in reversed(block.ops): + if op has an attribute named 'sub_block': + Get the sub-block(`s_block`) from op's attribute. + Create a new block(`grad_s_block`), whose father is `s_block`. + Invoke _append_backward_ops_(), with `block=s_block` and `target_block=grad_s_block` + + Invoke `core.get_grad_op_desc()` to get op's grad_op. + Insert name correspondings between variables and their gradients of the grad_op to grad_to_var + Assign grad_s_block to grad_op as it's 'sub_block' attribute. + Append grad_op to current target_block. +``` + +The first invoking of `_append_backward_ops_()` is initiated by `append_backward()`, in which parameters `block` and `target_block` are all assigned with root block(the block with index 0). + +### Corner Cases of `grad_op` Creating + +In the previous section, we show the regular process of `grad_op` creating. However, in some corner cases, the conventional algorithm is not enough to get the correct result and appending handling is required. These additional processes run after the algorithm mentioned above and do some special adjusts on its output `grad_op`s. + +#### Shared Variables + +If a variable is read by more than one `op` in the forward pass, its gradient is likely to be written by more than one `grad_op`s in the next backward pass. To make the gradient result being the sum of all `grad_op`s' outputs instead of the last running one, we assign each output with a temporary variable and then add a `sum_op` to add them up. + +For the debug convenience, if the final gradient name is `w@GRAD`, it's corresponding temporary variables will be named as `w@GRAD@RENAME@0`, `w@GRAD@RENAME@1`... + +See function `_addup_repetitive_outputs_` in `backward.py` for implementation details. + +#### No Gradient Variables + +In our framework, variables can be marked as *no_gradient*, it means that the gradient of this variable is unnecessary and can be considered as zero in model training. Apparently, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in backward pass. + +But these unnecessary gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros. + +This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False). + +### Creating Backward Variables + +Up to now, we have completed all creating and adjusting jobs of `grad_op`s. However, backward variables have not been created. Now they are only represented by `grad_op`'s input and output arguments. The backward variable creating job will be done by: + +```python +def _append_backward_vars_(block, + start_op_idx, + grad_to_var, + grad_info_map): + """ + Create new variables required by backward pass. + + Args: + block(Block): the block where new variables will be created + start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created + grad_to_var(dict): + key(str): grad variable name + val(str): corresponding forward variable name + In most cases, this dict is generated by _append_backward_ops_() + grad_info_map(dict)(output argument): + key(str): forward variable name + val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index + """ +``` + +Given a `block`, this function traverses all the `grad_op`s in it(The argument `start_op_idx` indicates where the grad_op sequence starts.) and creates all the uncreated outputs. The *pseudo-code* shows this process: + +``` +for op in block.ops[start_op_idx : ]: + + if op has an attribute named 'sub_block': + Get the sub-block(`s_block`) from op's attribute. + Invoke _append_backward_vars_(), with `block=s_block` + + for var_name in op.all_output_names(): + if block.has_var_recursive(var_name) or var_name is the name of empty variable: + continue + create a new variable named 'var_name' in block + if grad_to_var.has_key(var_name): + set grad_info_map[grad_to_var[var_name]] as a tuple of (var_name. block) + + do op's var type inference + do op's shape inference +``` diff --git a/paddle/framework/images/duplicate_op.graffle b/doc/design/images/duplicate_op.graffle similarity index 100% rename from paddle/framework/images/duplicate_op.graffle rename to doc/design/images/duplicate_op.graffle diff --git a/paddle/framework/images/duplicate_op.png b/doc/design/images/duplicate_op.png similarity index 100% rename from paddle/framework/images/duplicate_op.png rename to doc/design/images/duplicate_op.png diff --git a/paddle/framework/images/duplicate_op2.graffle b/doc/design/images/duplicate_op2.graffle similarity index 100% rename from paddle/framework/images/duplicate_op2.graffle rename to doc/design/images/duplicate_op2.graffle diff --git a/paddle/framework/images/duplicate_op2.png b/doc/design/images/duplicate_op2.png similarity index 100% rename from paddle/framework/images/duplicate_op2.png rename to doc/design/images/duplicate_op2.png diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index fa1b6a372728ccac128d2e6e79a6514b8884ea3f..bae42593ddc6f7a7eb47d603752ad6efa9820b45 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -15,7 +15,7 @@ 获取PaddlePaddle的Docker镜像 ------------------------------ -执行下面的命令获取最新的PaddlePaddle Docker镜像 +执行下面的命令获取最新的PaddlePaddle Docker镜像,版本为cpu_avx_mkl: .. code-block:: bash @@ -27,7 +27,7 @@ docker pull docker.paddlepaddle.org/paddle -下载GPU版本的Docker镜像: +下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像: .. code-block:: bash @@ -54,7 +54,7 @@ .. _docker_run: 在Docker中执行PaddlePaddle训练程序 ------------------------------- +---------------------------------- 假设您已经在当前目录(比如在/home/work)编写了一个PaddlePaddle的程序 :code:`train.py` (可以参考 `PaddlePaddleBook `_ @@ -82,7 +82,7 @@ .. _docker_run_book: 使用Docker启动PaddlePaddle Book教程 ------------------------------- +----------------------------------- 使用Docker可以快速在本地启动一个包含了PaddlePaddle官方Book教程的Jupyter Notebook,可以通过网页浏览。 PaddlePaddle Book是为用户和开发者制作的一个交互式的Jupyter Notebook。 diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 06012bf65e75c32957516f6b7f62e09480871b84..56a7c68e4d39c45249fa55a964dc48b7081596a6 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -16,7 +16,7 @@ After you've read above tutorials you may proceed the following steps. Pull PaddlePaddle Docker Image ------------------------------ -Run the following command to download the latest Docker images: +Run the following command to download the latest Docker images, the version is cpu_avx_mkl: .. code-block:: bash @@ -28,7 +28,7 @@ For users in China, we provide a faster mirror: docker pull docker.paddlepaddle.org/paddle -Download GPU version images: +Download GPU version (cuda8.0_cudnn5_avx_mkl) images: .. code-block:: bash @@ -58,7 +58,7 @@ and run: .. _docker_run: Launch your training program in Docker ------------------------------- +-------------------------------------- Assume that you have already written a PaddlePaddle program named :code:`train.py` under directory :code:`/home/work` (refer to diff --git a/doc/getstarted/build_and_install/pip_install_cn.rst b/doc/getstarted/build_and_install/pip_install_cn.rst index a4587f82a984acf243f49834e707fcd66d5b1252..0c741e936b46eda5e7165e4ee54b545b14a28a19 100644 --- a/doc/getstarted/build_and_install/pip_install_cn.rst +++ b/doc/getstarted/build_and_install/pip_install_cn.rst @@ -11,14 +11,14 @@ PaddlePaddle可以使用常用的Python包管理工具 ------------------------------ -执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件。 +执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件,版本为cpu_avx_openblas。 .. code-block:: bash pip install paddlepaddle -如果需要安装支持GPU的版本,需要执行: +如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: .. code-block:: bash diff --git a/doc/getstarted/build_and_install/pip_install_en.rst b/doc/getstarted/build_and_install/pip_install_en.rst index 55e31560a0f5087ab69966a6281c6c8573c04204..285ed09805b09790beaef014f6813c227aff33ac 100644 --- a/doc/getstarted/build_and_install/pip_install_en.rst +++ b/doc/getstarted/build_and_install/pip_install_en.rst @@ -12,14 +12,14 @@ Install Using pip ------------------------------ Run the following command to install PaddlePaddle on the current -machine, it will also download requirements. +machine, it will also download requirements, the version is cpu_avx_openblas. .. code-block:: bash pip install paddlepaddle -If you wish to install GPU version, just run: +If you wish to install GPU version (cuda7.5_cudnn5_avx_openblas), just run: .. code-block:: bash diff --git a/doc/getstarted/index_cn.rst b/doc/getstarted/index_cn.rst index a9087be6f350c5656cabb0c64ba0f200d1c666cc..9f6ee25987d51dcca3a37cf0f62a70a5a5a2d89a 100644 --- a/doc/getstarted/index_cn.rst +++ b/doc/getstarted/index_cn.rst @@ -7,13 +7,13 @@ ++++++++ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.04以及MacOS 10.12,并安装有Python2.7。 -执行下面的命令完成快速安装: +执行下面的命令完成快速安装,版本为cpu_avx_openblas: .. code-block:: bash pip install paddlepaddle -如果需要安装支持GPU的版本,需要执行: +如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: .. code-block:: bash diff --git a/doc/getstarted/index_en.rst b/doc/getstarted/index_en.rst index d14e3f5c0cc90792fce9cb82e65da482c44dc433..063d9d880c82550f7f5d47d3d0b1fff59865bca7 100644 --- a/doc/getstarted/index_en.rst +++ b/doc/getstarted/index_en.rst @@ -8,13 +8,13 @@ Quick Install You can use pip to install PaddlePaddle with a single command, supports CentOS 6 above, Ubuntu 14.04 above or MacOS 10.12, with Python 2.7 installed. -Simply run the following command to install: +Simply run the following command to install, the version is cpu_avx_openblas: .. code-block:: bash pip install paddlepaddle -If you need to install GPU version, run: +If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run: .. code-block:: bash diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 738684795d8170ffd5c5b2bf19e6e150219332d0..6788cb34fbaf5941cbb1537c7a83577c623bf76a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -5,10 +5,18 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) -cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) +if (WITH_GPU) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto) +else() + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto) +endif () cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) -cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) +if (WITH_GPU) + nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor) +else() + cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) +endif() cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md deleted file mode 100644 index ac60be572419b62f4beb644ff192d413c35e19bb..0000000000000000000000000000000000000000 --- a/paddle/framework/backward.md +++ /dev/null @@ -1,100 +0,0 @@ -# Operator/expression 's Backward - -## Motivation - -In Neural Network, most models are solved by the backpropagation algorithm(known as **BP**) at present. Technically, BP calculates the gradient of the loss function, then propagates it back through the networks following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph. The operator/expression's backward pass will be generated with respect to the forward pass. - -## Implementation - -In this design doc, we exported only one API for generating the backward pass. - -```c++ -std::unique_ptr Backward(const OperatorBase& forwardOp, - const std::unordered_set& no_grad_vars); -``` - -The implementation behind it can be divided into two parts, **Backward Operator Creating** and **Backward Operator Building**. - -### Backward Operator Registry - -A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs, and output gradients and then calculate its input gradients. - -| | forward operator | backward operator -| ---------------------- | ---------------- |------------------------- | -| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | -| **Operator::outputs_** | Outputs | InputGradients | - - In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced. - -For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro: - -```cpp -REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad); -``` - -`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively. - -`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name. - -### Backward Opeartor Creating - -Given a certain forward operator, we can get its corresponding backward operator by calling: - -```cpp -OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op); -``` - -The function `BuildGradOp` will sequentially execute following processes: - -1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`. - -2. Build two maps named `inputs` and `outputs` to temporarily store backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing. - -3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`. - -4. Building backward operator with `inputs`, `outputs` and forward operator's attributes. - -### Backward Network Building - -A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing. - -1. Op - - When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in no gradient set, then return a special `NOP`. - -2. NetOp - - In our design, the network itself is also a kind of operator(**NetOp**). So the operators contained by a big network may be some small network. When the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp. - -3. RnnOp - - RnnOp is a nested stepnet operator. Backward module needs to recusively call `Backward` for every stepnet. - -4. Sharing Variables - - As illustrated in the figure 1 and figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable. - -

-
- -​ Figure 1. Sharing variables in operators. - -

- -​ Sharing variable between operators or same input variable used in multiple operators can lead to duplicate gradient variables. As illustrated in figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting. - -

-
- -​ Figure 2. Replace sharing variable's gradient with `Add` operator. - -

- -​ Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent its position in the clockwise direction. - -5. Part of the Gradient is Zero. - - In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implementation, we insert a special `fillZeroLike` operator. - - -Follow these rules above, then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it. diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 2191dd3783d5ed7bb59b96c70d38a72bb0b2fee7..bd6d301c12e0611c5b01c3ff58869dbeb96b268e 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -27,9 +27,8 @@ limitations under the License. */ namespace paddle { namespace framework { -using DataTransformFn = - std::function ctx, - const Variable& in, Variable* out)>; +using DataTransformFn = std::function; using KernelTypePair = std::pair; struct KernelTypePairHash { diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 4e2141ecd2ebe35402a8a04613702a2f79f6a179..5f05e881fa16eead1dc690f85375706bf3cd3e6d 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1}); auto kernel2 = GenFromBit({0, 0, 1, 0}); auto kernel3 = GenFromBit({0, 0, 1, 1}); -void TransDataType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value++; } -void TransDataLayout_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value--; } -void TransLibraryType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value += 2; } @@ -83,7 +83,8 @@ TEST(DataTransform, Register) { using namespace paddle::platform; auto& instance = DataTransformFnMap::Instance(); - std::vector ctx; + ASSERT_EQ(instance.Map().size(), 3UL); + DeviceContext* ctx = nullptr; paddle::framework::Variable in; paddle::framework::Variable out; diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 31749743a58835ee2209d5a448f28b011cb3f7af..bf1f0471ccbfccf13cb6f74c8088da7acd68ec0b 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -14,18 +14,17 @@ limitations under the License. */ #include "paddle/framework/executor.h" -#include -#include -#include #include -#include +#include "gflags/gflags.h" #include "paddle/framework/feed_fetch_type.h" #include "paddle/framework/lod_rank_table.h" -#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/scope.h" + +DEFINE_bool(check_nan_inf, false, + "Checking whether operator produce NAN/INF or not. It will be " + "extremely slow so please use this flag wisely."); namespace paddle { namespace framework { @@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { } } +static void CheckTensorNANOrInf(const std::string& name, + const framework::Tensor& tensor) { + if (tensor.memory_size() == 0) { + return; + } + if (tensor.type().hash_code() != typeid(float).hash_code() && + tensor.type().hash_code() != typeid(double).hash_code()) { + return; + } + PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); + PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name); +} + void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { // TODO(tonyyang-svail): @@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(3) << op->DebugString(); op->Run(*local_scope, place_); + if (FLAGS_check_nan_inf) { + for (auto& vname : op->OutputVars(true)) { + auto* var = local_scope->FindVar(vname); + if (var == nullptr) continue; + if (var->IsType()) { + CheckTensorNANOrInf(vname, var->Get()); + } + } + } } if (create_vars && create_local_scope) { scope->DeleteScope(local_scope); diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index f8a3be9a82bdbaf82550634d36122eb7bbe85e54..7b6dc09bdb5535488c8c4dbc71c9cd6a7998bd0b 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -189,62 +189,16 @@ void AppendLoD(LoD *lod, const LoD &lod_length) { void SerializeToStream(std::ostream &os, const LoDTensor &tensor, const platform::DeviceContext &dev_ctx) { - // TODO(typhoonzero): serialize to ostream - { // the 1st field, uint32_t version + { // the 1st field, uint32_t version for LoDTensor constexpr uint32_t version = 0; os.write(reinterpret_cast(&version), sizeof(version)); } - { // the 2nd field, tensor description - // int32_t size - // void* protobuf message - proto::TensorDesc desc; - desc.set_data_type(framework::ToDataType(tensor.type())); - auto dims = framework::vectorize(tensor.dims()); - auto *pb_dims = desc.mutable_dims(); - pb_dims->Resize(static_cast(dims.size()), 0); - std::copy(dims.begin(), dims.end(), pb_dims->begin()); - int32_t size = desc.ByteSize(); - os.write(reinterpret_cast(&size), sizeof(size)); - auto out = desc.SerializeAsString(); - os.write(out.data(), size); - } - { // the 3rd field, tensor data - uint64_t size = tensor.memory_size(); - auto *data_ptr = tensor.data(); - PADDLE_ENFORCE(size < std::numeric_limits::max(), - "Index overflow when writing tensor"); - if (platform::is_gpu_place(tensor.place())) { -#ifdef PADDLE_WITH_CUDA - constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB - std::unique_ptr buf(new char[kBufSize]); - auto &gpu_dev_ctx = - static_cast(dev_ctx); - platform::CPUPlace cpu; - uintptr_t data = reinterpret_cast(data_ptr); - while (size != 0) { - size_t size_to_write = std::min(kBufSize, static_cast(size)); - memory::Copy(cpu, buf.get(), - boost::get(tensor.place()), - reinterpret_cast(data), size_to_write, - gpu_dev_ctx.stream()); - gpu_dev_ctx.Wait(); - os.write(buf.get(), size_to_write); - data += size_to_write; - size -= size_to_write; - } -#else - PADDLE_THROW("Unexpected branch"); -#endif - } else { - os.write(static_cast(data_ptr), - static_cast(size)); - } - } - { // the 4th field, lod information - // uint64_t lod_level - // uint64_t lod_level_1 size in byte. - // int* lod_level_1 data - // ... + { + // the 2st field, LoD information + // uint64_t lod_level + // uint64_t lod_level_1 size in byte. + // int* lod_level_1 data + // ... auto lod = tensor.lod(); uint64_t size = lod.size(); os.write(reinterpret_cast(&size), sizeof(size)); @@ -256,49 +210,19 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor, static_cast(size)); } } + // the 3st field, Tensor + SerializeToStream(os, static_cast(tensor), dev_ctx); } void DeserializeFromStream(std::istream &is, LoDTensor *tensor) { - uint32_t version; - is.read(reinterpret_cast(&version), sizeof(version)); - PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); - proto::TensorDesc desc; - { // int32_t size - // proto buffer - int32_t size; - is.read(reinterpret_cast(&size), sizeof(size)); - std::unique_ptr buf(new char[size]); - is.read(reinterpret_cast(buf.get()), size); - PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), - "Cannot parse tensor desc"); - } - { // read tensor - std::vector dims; - dims.reserve(static_cast(desc.dims().size())); - std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); - tensor->Resize(framework::make_ddim(dims)); - - void *buf; - platform::Place cpu = platform::CPUPlace(); - switch (desc.data_type()) { - case proto::FP32: - buf = tensor->mutable_data(cpu); - break; - case proto::FP64: - buf = tensor->mutable_data(cpu); - break; - case proto::INT32: - buf = tensor->mutable_data(cpu); - break; - case proto::INT64: - buf = tensor->mutable_data(cpu); - break; - default: - PADDLE_THROW("DataType %d not supported", desc.data_type()); - } - is.read(static_cast(buf), tensor->memory_size()); - } - { // read lod + { + // the 1st field, unit32_t version for SelectedRows + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + } + { + // the 2st field, LoD information uint64_t lod_level; is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); auto &lod = *tensor->mutable_lod(); @@ -312,6 +236,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) { lod[i] = tmp; } } + // the 3st filed, Tensor + DeserializeFromStream(is, static_cast(tensor)); } } // namespace framework diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 02d84b68233f2fdfc66e1df2fc7ce20307cadd94..0747c8db531d6ae443d76591b945cce0c9bbea2b 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -126,6 +126,20 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { EXPECT_NE(t1.data(), lod_tensor_.data()); } +TEST_F(LoDTensorTester, SerializeAndDeserialize) { + LoDTensor dst_tensor; + platform::CPUDeviceContext cpu_ctx((platform::CPUPlace())); + std::ostringstream oss; + SerializeToStream(oss, lod_tensor_, cpu_ctx); + std::istringstream iss(oss.str()); + DeserializeFromStream(iss, &dst_tensor); + float* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < kLodTensorSize; ++i) { + EXPECT_EQ(dst_ptr[i], i); + } + EXPECT_EQ(dst_tensor.lod(), lod_tensor_.lod()); +} + TEST(LodExpand, test) { LoD lod{{0, 2}}; LoDTensor tensor; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index c0be11294c4a6b49ae4bc2f805f76e9f04508349..a3ce96c409675ad52a811586c736ca22b5c7e99e 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; +const platform::DeviceContext* GetDeviceContext( + framework::KernelTypePair& kernel_pair) { + auto& actual_kernel_key = kernel_pair.first; + auto& expected_kernel_key = kernel_pair.second; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + + if (platform::is_gpu_place(actual_kernel_key.place_) && + platform::is_cpu_place(expected_kernel_key.place_)) { + return pool.Get(actual_kernel_key.place_); + } else if (platform::is_cpu_place(actual_kernel_key.place_) && + platform::is_gpu_place(expected_kernel_key.place_)) { + return pool.Get(expected_kernel_key.place_); + } else { + PADDLE_THROW( + "Currently, model parallelism is only supported between CPU and CUDA"); + } +} + void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); @@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope, "CPU and other devices. For example, multi-GPU model " "parallelism will failed."); } else { + auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key); const DataTransformFn* trans_fun = - DataTransformFnMap::Instance().GetNullable( - std::make_pair(actual_kernel_key, expected_kernel_key)); + DataTransformFnMap::Instance().GetNullable(kernel_pair); if (trans_fun) { auto input_vars = this->InputVars(); // TODO(qijun) filter the input vars that do not need to be transformed @@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope, } if (!need_trans.empty()) { - // TODO(qijun) get appropriate DeviceContext from DeviceContext pool - platform::DeviceContext* trans_dev_ctx = nullptr; - std::vector trans_dev_ctx_vec{trans_dev_ctx}; + auto trans_dev_ctx = GetDeviceContext(kernel_pair); // Wait for transform starting dev_ctx->Wait(); for (auto var_name : need_trans) { - (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)), + (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)), scope.FindVar(var_name + framework::KernelTypeToString( expected_kernel_key))); } // Wait for data transform finishing - for (auto ctx : trans_dev_ctx_vec) { - ctx->Wait(); - } + trans_dev_ctx->Wait(); } } } diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc index c74459c9dd7006a24615b1d6df041583088fb25c..82adfa7123a3cf40d929021602c45fe7d2e34ffa 100644 --- a/paddle/framework/selected_rows.cc +++ b/paddle/framework/selected_rows.cc @@ -12,5 +12,58 @@ limitations under the License. */ #include "paddle/framework/selected_rows.h" namespace paddle { -namespace framework {} // namespace framework +namespace framework { +void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, + const platform::DeviceContext& dev_ctx) { + { // the 1st field, uint32_t version + constexpr uint32_t version = 0; + os.write(reinterpret_cast(&version), sizeof(version)); + } + { + // the 2st field, rows information + auto& rows = selected_rows.rows(); + uint64_t size = rows.size(); + os.write(reinterpret_cast(&size), sizeof(size)); + for (uint64_t i = 0; i < size; ++i) { + os.write(reinterpret_cast(&rows[i]), sizeof(rows[i])); + } + } + { + // the 3st field, the height of SelectedRows + int64_t height = selected_rows.height(); + os.write(reinterpret_cast(&height), sizeof(height)); + } + // the 4st field, Tensor data + SerializeToStream(os, selected_rows.value(), dev_ctx); +} + +void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) { + auto tensor = *selected_rows->mutable_value(); + { + // the 1st field, unit32_t version for SelectedRows + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + } + { + // the 2st field, rows information + uint64_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + auto& rows = *selected_rows->mutable_rows(); + rows.resize(size); + for (uint64_t i = 0; i < size; ++i) { + is.read(reinterpret_cast(&rows[i]), sizeof(int64_t)); + } + } + { + // the 3st field, the height of the SelectedRows + int64_t height; + is.read(reinterpret_cast(&height), sizeof(int64_t)); + selected_rows->set_height(height); + } + // the 4st field, tensor which contains the data + DeserializeFromStream(is, &tensor); +} + +} // namespace framework } // namespace paddle diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 0332b91323e3a4b4b80e02302ad3dcafe0986cde..699e392688e9889f050592172f8bfc45f855d0b1 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -59,5 +59,14 @@ class SelectedRows { int64_t height_; }; +/* + * Serialize/Desiralize SelectedRows to std::ostream + * You can pass ofstream or ostringstream to serilize to file + * or to a in memory string. GPU tensor will be copied to CPU. + */ +void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, + const platform::DeviceContext& dev_ctx); +void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc index 4ee13a65d72e44693573397bb686b355effb2227..75487c4010391aa9e519d73058184fa936dabb84 100644 --- a/paddle/framework/selected_rows_test.cc +++ b/paddle/framework/selected_rows_test.cc @@ -43,5 +43,19 @@ TEST_F(SelectedRowsTester, complete_dims) { ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100})); } +TEST_F(SelectedRowsTester, SerializeAndDeseralize) { + SelectedRows dst_tensor; + platform::CPUDeviceContext cpu_ctx(place_); + std::ostringstream oss; + + SerializeToStream(oss, *selected_rows_, cpu_ctx); + + std::istringstream iss(oss.str()); + DeserializeFromStream(iss, &dst_tensor); + + ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows()); + ASSERT_EQ(selected_rows_->height(), dst_tensor.height()); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index ca76a9fcb9079bab22f7b192c45903852c91797f..a1b4a03289eca4c8b9d8c23ede4221853cb31f79 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -15,12 +15,13 @@ #include #include +namespace framework = paddle::framework; +namespace platform = paddle::platform; + TEST(Tensor, Dims) { - using namespace paddle::framework; - using namespace paddle::platform; - Tensor tt; + framework::Tensor tt; tt.Resize({2, 3, 4}); - DDim dims = tt.dims(); + framework::DDim dims = tt.dims(); ASSERT_EQ(arity(dims), 3); for (int i = 0; i < 3; ++i) { EXPECT_EQ(i + 2, dims[i]); @@ -28,12 +29,12 @@ TEST(Tensor, Dims) { } TEST(Tensor, DataAssert) { - paddle::framework::Tensor src_tensor; + framework::Tensor src_tensor; bool caught = false; try { src_tensor.data(); - } catch (paddle::platform::EnforceNotMet err) { + } catch (platform::EnforceNotMet err) { caught = true; std::string msg = "holder_ should not be null\nTensor holds no memory. Call " @@ -50,61 +51,65 @@ TEST(Tensor, DataAssert) { because Memory::Alloc() and Memory::Free() have not been ready. */ TEST(Tensor, MutableData) { - using namespace paddle::framework; - using namespace paddle::platform; { - Tensor src_tensor; + framework::Tensor src_tensor; float* p1 = nullptr; float* p2 = nullptr; // initialization - p1 = src_tensor.mutable_data(make_ddim({1, 2, 3}), CPUPlace()); + p1 = src_tensor.mutable_data(framework::make_ddim({1, 2, 3}), + platform::CPUPlace()); EXPECT_NE(p1, nullptr); // set src_tensor a new dim with large size // momery is supposed to be re-allocated - p2 = src_tensor.mutable_data(make_ddim({3, 4}), CPUPlace()); + p2 = src_tensor.mutable_data(framework::make_ddim({3, 4}), + platform::CPUPlace()); EXPECT_NE(p2, nullptr); EXPECT_NE(p1, p2); // set src_tensor a new dim with same size // momery block is supposed to be unchanged - p1 = src_tensor.mutable_data(make_ddim({2, 2, 3}), CPUPlace()); + p1 = src_tensor.mutable_data(framework::make_ddim({2, 2, 3}), + platform::CPUPlace()); EXPECT_EQ(p1, p2); // set src_tensor a new dim with smaller size // momery block is supposed to be unchanged - p2 = src_tensor.mutable_data(make_ddim({2, 2}), CPUPlace()); + p2 = src_tensor.mutable_data(framework::make_ddim({2, 2}), + platform::CPUPlace()); EXPECT_EQ(p1, p2); } #ifdef PADDLE_WITH_CUDA { - Tensor src_tensor; + framework::Tensor src_tensor; float* p1 = nullptr; float* p2 = nullptr; // initialization - p1 = src_tensor.mutable_data(make_ddim({1, 2, 3}), CUDAPlace()); + p1 = src_tensor.mutable_data(framework::make_ddim({1, 2, 3}), + platform::CUDAPlace()); EXPECT_NE(p1, nullptr); // set src_tensor a new dim with large size // momery is supposed to be re-allocated - p2 = src_tensor.mutable_data(make_ddim({3, 4}), CUDAPlace()); + p2 = src_tensor.mutable_data(framework::make_ddim({3, 4}), + platform::CUDAPlace()); EXPECT_NE(p2, nullptr); EXPECT_NE(p1, p2); // set src_tensor a new dim with same size // momery block is supposed to be unchanged - p1 = src_tensor.mutable_data(make_ddim({2, 2, 3}), CUDAPlace()); + p1 = src_tensor.mutable_data(framework::make_ddim({2, 2, 3}), + platform::CUDAPlace()); EXPECT_EQ(p1, p2); // set src_tensor a new dim with smaller size // momery block is supposed to be unchanged - p2 = src_tensor.mutable_data(make_ddim({2, 2}), CUDAPlace()); + p2 = src_tensor.mutable_data(framework::make_ddim({2, 2}), + platform::CUDAPlace()); EXPECT_EQ(p1, p2); } #endif } TEST(Tensor, ShareDataWith) { - using namespace paddle::framework; - using namespace paddle::platform; { - Tensor src_tensor; - Tensor dst_tensor; + framework::Tensor src_tensor; + framework::Tensor dst_tensor; // Try to share data form uninitialized tensor bool caught = false; try { @@ -121,16 +126,18 @@ TEST(Tensor, ShareDataWith) { } ASSERT_TRUE(caught); - src_tensor.mutable_data(make_ddim({2, 3, 4}), CPUPlace()); + src_tensor.mutable_data(framework::make_ddim({2, 3, 4}), + platform::CPUPlace()); dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } #ifdef PADDLE_WITH_CUDA { - Tensor src_tensor; - Tensor dst_tensor; - src_tensor.mutable_data(make_ddim({2, 3, 4}), CUDAPlace()); + framework::Tensor src_tensor; + framework::Tensor dst_tensor; + src_tensor.mutable_data(framework::make_ddim({2, 3, 4}), + platform::CUDAPlace()); dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } @@ -138,13 +145,12 @@ TEST(Tensor, ShareDataWith) { } TEST(Tensor, Slice) { - using namespace paddle::framework; - using namespace paddle::platform; { - Tensor src_tensor; - src_tensor.mutable_data(make_ddim({5, 3, 4}), CPUPlace()); - Tensor slice_tensor = src_tensor.Slice(1, 3); - DDim slice_dims = slice_tensor.dims(); + framework::Tensor src_tensor; + src_tensor.mutable_data(framework::make_ddim({5, 3, 4}), + platform::CPUPlace()); + framework::Tensor slice_tensor = src_tensor.Slice(1, 3); + framework::DDim slice_dims = slice_tensor.dims(); ASSERT_EQ(arity(slice_dims), 3); EXPECT_EQ(slice_dims[0], 2); EXPECT_EQ(slice_dims[1], 3); @@ -153,11 +159,12 @@ TEST(Tensor, Slice) { uintptr_t src_data_address = reinterpret_cast(src_tensor.data()); uintptr_t src_mutable_data_address = reinterpret_cast( - src_tensor.mutable_data(src_tensor.dims(), CPUPlace())); + src_tensor.mutable_data(src_tensor.dims(), platform::CPUPlace())); uintptr_t slice_data_address = reinterpret_cast(slice_tensor.data()); - uintptr_t slice_mutable_data_address = reinterpret_cast( - slice_tensor.mutable_data(slice_tensor.dims(), CPUPlace())); + uintptr_t slice_mutable_data_address = + reinterpret_cast(slice_tensor.mutable_data( + slice_tensor.dims(), platform::CPUPlace())); EXPECT_EQ(src_data_address, src_mutable_data_address); EXPECT_EQ(slice_data_address, slice_mutable_data_address); EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address); @@ -165,22 +172,25 @@ TEST(Tensor, Slice) { #ifdef PADDLE_WITH_CUDA { - Tensor src_tensor; - src_tensor.mutable_data(make_ddim({6, 9}), CUDAPlace()); - Tensor slice_tensor = src_tensor.Slice(2, 6); - DDim slice_dims = slice_tensor.dims(); + framework::Tensor src_tensor; + src_tensor.mutable_data(framework::make_ddim({6, 9}), + platform::CUDAPlace()); + framework::Tensor slice_tensor = src_tensor.Slice(2, 6); + framework::DDim slice_dims = slice_tensor.dims(); ASSERT_EQ(arity(slice_dims), 2); EXPECT_EQ(slice_dims[0], 4); EXPECT_EQ(slice_dims[1], 9); uintptr_t src_data_address = reinterpret_cast(src_tensor.data()); - uintptr_t src_mutable_data_address = reinterpret_cast( - src_tensor.mutable_data(src_tensor.dims(), CUDAPlace())); + uintptr_t src_mutable_data_address = + reinterpret_cast(src_tensor.mutable_data( + src_tensor.dims(), platform::CUDAPlace())); uintptr_t slice_data_address = reinterpret_cast(slice_tensor.data()); - uintptr_t slice_mutable_data_address = reinterpret_cast( - slice_tensor.mutable_data(slice_tensor.dims(), CUDAPlace())); + uintptr_t slice_mutable_data_address = + reinterpret_cast(slice_tensor.mutable_data( + slice_tensor.dims(), platform::CUDAPlace())); EXPECT_EQ(src_data_address, src_mutable_data_address); EXPECT_EQ(slice_data_address, slice_mutable_data_address); EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address); @@ -189,23 +199,19 @@ TEST(Tensor, Slice) { } TEST(Tensor, ReshapeToMatrix) { - using namespace paddle::framework; - using namespace paddle::platform; - Tensor src; - int* src_ptr = src.mutable_data({2, 3, 4, 9}, CPUPlace()); + framework::Tensor src; + int* src_ptr = src.mutable_data({2, 3, 4, 9}, platform::CPUPlace()); for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { src_ptr[i] = i; } - Tensor res = ReshapeToMatrix(src, 2); + framework::Tensor res = framework::ReshapeToMatrix(src, 2); ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[1], 4 * 9); } TEST(Tensor, Layout) { - using namespace paddle::framework; - using namespace paddle::platform; - Tensor src; - ASSERT_EQ(src.layout(), DataLayout::kNHWC); - src.set_layout(DataLayout::kAnyLayout); - ASSERT_EQ(src.layout(), DataLayout::kAnyLayout); + framework::Tensor src; + ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC); + src.set_layout(framework::DataLayout::kAnyLayout); + ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout); } diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..7efc649d0bcda67c663d148e83bcbb6789b0f371 --- /dev/null +++ b/paddle/framework/tensor_util.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/tensor_util.h" + +namespace paddle { +namespace framework { +template +struct AnyDTypeVisitor { + Predicate predicate_; + const Tensor& tensor_; + const DevCtx& ctx_; + Tensor* out_; + + AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, + Tensor* out) + : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} + + template + void operator()() const { + auto t = EigenVector::Flatten(tensor_); + auto o = EigenScalar::From(*out_); + // return any of predicate_(t) is true. + o.device(*ctx_.eigen_device()) = predicate_(t).any(); + } +}; + +template +inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, + const DevCtx& ctx, framework::Tensor* out) { + VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( + predicate, tensor, ctx, out)); +} + +template +struct AnyVisitor : public boost::static_visitor { + const framework::Tensor& tensor_; + Predicate predicate_; + + AnyVisitor(const framework::Tensor& tensor, Predicate predicate) + : tensor_(tensor), predicate_(std::move(predicate)) {} + + template + bool operator()(const Place& place) const { + framework::Tensor out; + out.Resize({1}); + out.mutable_data(place); + auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + AnyImpl(predicate_, tensor_, *ctx, &out); + return this->GetResult(out, place); + } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPlace& gpu) const { + platform::CPUPlace cpu; + framework::Tensor tmp; + tmp.Resize({1}); + tmp.mutable_data(cpu); + auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu); + gpuctx->Wait(); + CopyFrom(out, cpu, *gpuctx, &tmp); + gpuctx->Wait(); + return GetResult(tmp, cpu); + } + + bool GetResult(const framework::Tensor& out, + const platform::CPUPlace& cpu) const { + return *out.data(); + } +}; + +template +inline bool Any(const framework::Tensor& tensor, Predicate predicate) { + AnyVisitor visitor(tensor, predicate); + auto place = tensor.place(); + return platform::VisitPlace(place, visitor); +} + +struct HasNANPredicate { + template + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isnan()) { + // Cast eigen_vector to vector of bool. true if is inf. + return eigen_vec.isnan(); + } +}; + +bool HasNAN(const framework::Tensor& tensor) { + HasNANPredicate predicate; + return Any(tensor, predicate); +} + +struct HasInfPredicate { + template + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isinf()) { + // Cast eigen_vector to vector of bool. true if is inf. + return eigen_vec.isinf(); + } +}; + +bool HasInf(const framework::Tensor& tensor) { + HasInfPredicate predicate; + return Any(tensor, predicate); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor_util.cu b/paddle/framework/tensor_util.cu new file mode 120000 index 0000000000000000000000000000000000000000..b00e6e59d93328bf3142597ea4de0dc225501e56 --- /dev/null +++ b/paddle/framework/tensor_util.cu @@ -0,0 +1 @@ +./tensor_util.cc \ No newline at end of file diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index ea4e4f22ea82ccc9f8b683d2fd773a7bc37f78a3..6a21f8db1e3966fd23eee0da2346b2d61f9321fb 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/data_type.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/framework.pb.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { @@ -205,5 +209,109 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { src_ptr, size); } +// Returns true if a tensor contains NAN, i.e., Not A Number. +bool HasNAN(const framework::Tensor& tensor); + +// Returns true if a tensor contains Inf, i.e., Infinity. +bool HasInf(const framework::Tensor& tensor); + +inline void SerializeToStream(std::ostream& os, const Tensor& tensor, + const platform::DeviceContext& dev_ctx) { + // TODO(typhoonzero): serialize to ostream + { // the 1st field, uint32_t version + constexpr uint32_t version = 0; + os.write(reinterpret_cast(&version), sizeof(version)); + } + { // the 2nd field, tensor description + // int32_t size + // void* protobuf message + proto::TensorDesc desc; + desc.set_data_type(framework::ToDataType(tensor.type())); + auto dims = framework::vectorize(tensor.dims()); + auto* pb_dims = desc.mutable_dims(); + pb_dims->Resize(static_cast(dims.size()), 0); + std::copy(dims.begin(), dims.end(), pb_dims->begin()); + int32_t size = desc.ByteSize(); + os.write(reinterpret_cast(&size), sizeof(size)); + auto out = desc.SerializeAsString(); + os.write(out.data(), size); + } + { // the 3rd field, tensor data + uint64_t size = tensor.memory_size(); + auto* data_ptr = tensor.data(); + PADDLE_ENFORCE(size < std::numeric_limits::max(), + "Index overflow when writing tensor"); + if (platform::is_gpu_place(tensor.place())) { +#ifdef PADDLE_WITH_CUDA + constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB + std::unique_ptr buf(new char[kBufSize]); + auto& gpu_dev_ctx = + static_cast(dev_ctx); + platform::CPUPlace cpu; + uintptr_t data = reinterpret_cast(data_ptr); + while (size != 0) { + size_t size_to_write = std::min(kBufSize, static_cast(size)); + memory::Copy(cpu, buf.get(), + boost::get(tensor.place()), + reinterpret_cast(data), size_to_write, + gpu_dev_ctx.stream()); + gpu_dev_ctx.Wait(); + os.write(buf.get(), size_to_write); + data += size_to_write; + size -= size_to_write; + } +#else + PADDLE_THROW("Unexpected branch"); +#endif + } else { + os.write(static_cast(data_ptr), + static_cast(size)); + } + } +} + +inline void DeserializeFromStream(std::istream& is, Tensor* tensor) { + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); + proto::TensorDesc desc; + { // int32_t size + // proto buffer + int32_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + std::unique_ptr buf(new char[size]); + is.read(reinterpret_cast(buf.get()), size); + PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size), + "Cannot parse tensor desc"); + } + { // read tensor + std::vector dims; + dims.reserve(static_cast(desc.dims().size())); + std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); + tensor->Resize(framework::make_ddim(dims)); + + void* buf; + platform::Place cpu = platform::CPUPlace(); + // TODO(Yancey1989): use VisiterDataType instead of DataType switch + switch (desc.data_type()) { + case proto::FP32: + buf = tensor->mutable_data(cpu); + break; + case proto::FP64: + buf = tensor->mutable_data(cpu); + break; + case proto::INT32: + buf = tensor->mutable_data(cpu); + break; + case proto::INT64: + buf = tensor->mutable_data(cpu); + break; + default: + PADDLE_THROW("DataType %d not supported", desc.data_type()); + } + is.read(static_cast(buf), tensor->memory_size()); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc index f388c19f28ed28335818733f946d8eaf18464627..0dc5166fcabf77b48b8681ab1f050e2bc88f44ab 100644 --- a/paddle/framework/tensor_util_test.cc +++ b/paddle/framework/tensor_util_test.cc @@ -13,6 +13,7 @@ #include "paddle/framework/tensor_util.h" #include +#include #include namespace paddle { @@ -230,5 +231,78 @@ TEST(CopyToVector, Tensor) { #endif } +TEST(HasNAN, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + float* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 0.0; + buf[1] = NAN; + buf[2] = 0.0; + + ASSERT_TRUE(HasNAN(src)); +} + +TEST(HasInf, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + double* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 1.0; + buf[1] = INFINITY; + buf[2] = 0.0; + ASSERT_TRUE(HasInf(src)); +} + +TEST(Tensor, SerializeAndDeserialize) { + framework::Tensor src_tensor; + int array[6] = {1, 2, 3, 4, 5, 6}; + src_tensor.Resize({2, 3}); + int* src_ptr = src_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < 6; ++i) { + src_ptr[i] = array[i]; + } + { + framework::Tensor dst_tensor; + auto place = new platform::CPUPlace(); + platform::CPUDeviceContext cpu_ctx(*place); + std::ostringstream oss; + SerializeToStream(oss, src_tensor, cpu_ctx); + + std::istringstream iss(oss.str()); + DeserializeFromStream(iss, &dst_tensor); + int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < 5; ++i) { + ASSERT_EQ(dst_ptr[i], array[i]); + } + delete place; + } +#ifdef PADDLE_WITH_CUDA + { + Tensor gpu_tensor; + gpu_tensor.Resize({2, 3}); + Tensor dst_tensor; + + auto gpu_place = new platform::CUDAPlace(); + platform::CUDADeviceContext gpu_ctx(*gpu_place); + + CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); + + std::ostringstream oss; + SerializeToStream(oss, gpu_tensor, gpu_ctx); + + std::istringstream iss(oss.str()); + DeserializeFromStream(iss, &dst_tensor); + + int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(dst_ptr[i], array[i]); + } + + delete gpu_place; + } +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_util_test.cu b/paddle/framework/tensor_util_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..ebd35fdf6c2a1388fec23057070f723c8ef9da9c --- /dev/null +++ b/paddle/framework/tensor_util_test.cu @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/framework/tensor_util.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace framework { + +static __global__ void FillNAN(float* buf) { + buf[0] = 0.0; + buf[1] = 0.1; + buf[2] = NAN; +} +static __global__ void FillInf(float* buf) { + buf[0] = 0.0; + buf[1] = INFINITY; + buf[2] = 0.5; +} + +TEST(HasNAN, GPU) { + Tensor tensor; + platform::CUDAPlace gpu(0); + auto& pool = platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + float* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + ASSERT_TRUE(HasNAN(tensor)); +} + +TEST(HasInf, GPU) { + Tensor tensor; + platform::CUDAPlace gpu(0); + auto& pool = platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + float* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + ASSERT_TRUE(HasInf(tensor)); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h index 5f6b2d458f7ee764c22d203f285b78023b6012f3..bcd8190755083ec30687675602a1c95a9c15c69e 100644 --- a/paddle/framework/threadpool.h +++ b/paddle/framework/threadpool.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -25,10 +26,11 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef std::function Task; - class ThreadPool { public: + typedef std::packaged_task Task; + typedef std::function Fun; + /** * @brief Get a instance of threadpool, the thread number will * be specified as the number of hardware thread contexts @@ -61,13 +63,18 @@ class ThreadPool { /** * @brief Push a function to the queue, and will be scheduled and * executed if a thread is available. - * @param[in] Task will be pushed to the task queue. + * @param[in] Task, will be pushed to the task queue. + * @return std::future, we could wait for the task finished by + * f.wait(). */ - void Run(const Task& fn) { + std::future Run(const Fun& fn) { std::unique_lock lock(mutex_); - tasks_.push(fn); + Task task(std::bind(fn)); + std::future f = task.get_future(); + tasks_.push(std::move(task)); lock.unlock(); scheduled_.notify_one(); + return f; } /** @@ -110,7 +117,7 @@ class ThreadPool { break; } // pop a task from the task queue - auto task = tasks_.front(); + auto task = std::move(tasks_.front()); tasks_.pop(); --available_; diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc index 012d92a5edc415f0bb2f8a0ea38ffeb9549d54fa..50b6238cd8786be9d8cf2d5f821daadea12bd208 100644 --- a/paddle/framework/threadpool_test.cc +++ b/paddle/framework/threadpool_test.cc @@ -20,16 +20,21 @@ limitations under the License. */ namespace framework = paddle::framework; void do_sum(framework::ThreadPool* pool, std::atomic& sum, int cnt) { + std::vector> fs; for (int i = 0; i < cnt; ++i) { - pool->Run([&sum]() { sum.fetch_add(1); }); + auto f = pool->Run([&sum]() { sum.fetch_add(1); }); + fs.push_back(std::move(f)); + } + for (auto& f : fs) { + f.wait(); } } TEST(ThreadPool, ConcurrentInit) { framework::ThreadPool* pool; - int concurrent_cnt = 50; + int n = 50; std::vector threads; - for (int i = 0; i < concurrent_cnt; ++i) { + for (int i = 0; i < n; ++i) { std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); }); threads.push_back(std::move(t)); } @@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) { } } -TEST(ThreadPool, ConcurrentStart) { +TEST(ThreadPool, ConcurrentRun) { framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); std::atomic sum(0); std::vector threads; - int concurrent_cnt = 50; + int n = 50; // sum = (n * (n + 1)) / 2 - for (int i = 1; i <= concurrent_cnt; ++i) { + for (int i = 1; i <= n; ++i) { std::thread t(do_sum, pool, std::ref(sum), i); threads.push_back(std::move(t)); } @@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) { t.join(); } pool->Wait(); - EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2); + EXPECT_EQ(sum, ((n + 1) * n) / 2); } diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index de7b70e271b38ebe3a4c38704d0cced47d010788..cbdbf5335d32d55a0221728758025c9d2cb3e7d1 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -126,14 +126,165 @@ public: inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; } + } +}; + #ifdef PADDLE_MOBILE_INFERENCE - if (Device == DEVICE_TYPE_CPU) { - memory_.reset(); + +/* + * \brief Forward calculation of convolution, optimized for mobile. + */ +template +class GemmConvMobileFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + // TODO(hedaoyuan): Need to define some index macros, + // to avoid useing 0 and 1. + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; } -#endif + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + + TensorShape colShape; + real* colData = NULL; + + size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; + size_t colWidth = outputHeight * outputWidth; + // Max col matrix height 256, Max col matrix width 1024 + size_t stepColHeight = std::min(colHeight, static_cast(256)); + size_t stepColWidth = std::min(colWidth, static_cast(2048)); + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(stepColHeight * stepColWidth * sizeof(real)); + colData = reinterpret_cast(memory_->getBuf()); + } + + Im2ColMobileFunctor im2col; + size_t inputOffset = imShape.getElements(); + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + + int nStride = colWidth; + int kStride = colHeight; + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + if (needIm2col) { + real beta_ = beta; + for (size_t colHeightStart = 0; colHeightStart < colHeight; + colHeightStart += stepColHeight) { + for (size_t colWidthStart = 0; colWidthStart < colWidth; + colWidthStart += stepColWidth) { + int N = std::min(colWidth - colWidthStart, stepColWidth); + int K = std::min(colHeight - colHeightStart, stepColHeight); + // im2col + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW(), + dilationH(), + dilationW(), + colHeightStart, + K, + colWidthStart, + N); + + // gemm + int M = outputChannels / groups_; + BlasGemm::compute( + false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset + colHeightStart, + kStride, + colData, + N, + beta_, + outputData + g * outputOffset + colWidthStart, + nStride); + } + beta_ = 1.0; + } + } else { + int M = outputChannels / groups_; + int N = outputHeight * outputWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + inputData + g * inputOffset, + N, + beta, + outputData + g * outputOffset, + N); + } + } + inputData += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } + + memory_.reset(); } }; +#endif + /* * \brief Backward input calculation of convolution. */ @@ -348,7 +499,11 @@ public: } }; +#ifdef PADDLE_MOBILE_INFERENCE +REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction); +#else REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +#endif REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); #ifdef PADDLE_WITH_CUDA diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 0c37fc972484bfbede01d23652e384071bf883af..36a9bcf84e4b14965c83627821b71d1c7c0da1b2 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -98,4 +98,54 @@ public: int dilationWidth = 1); }; +template +class Im2ColMobileFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int dilationHeight, + int dilationWidth, + int colHeightStart, + int colHeightSize, + int colWidthStart, + int colWidthSize) { + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputWidth = colShape[4]; + + for (int colh = 0; colh < colHeightSize; colh++) { + int wOffset = (colHeightStart + colh) % filterWidth; + int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; + int c_im = (colHeightStart + colh) / filterWidth / filterHeight; + + for (int colw = 0; colw < colWidthSize; colw++) { + int h = (colWidthStart + colw) / outputWidth; + int w = (colWidthStart + colw) % outputWidth; + + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[colh * colWidthSize + colw] = static_cast(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[colh * colWidthSize + colw] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } +}; + } // namespace paddle diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 1f085538d81904dbd5b5d6bcd014adaed22e37d7..3ba866dcdd845403d52f7a85adfef08cbb11c305 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -138,4 +138,86 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } #endif +template +void TestIm2ColMobileFunctor() { + for (size_t channels : {32}) { + for (size_t inputHeight : {33, 100}) { + for (size_t inputWidth : {32, 96}) { + for (size_t filterHeight : {5}) { + for (size_t filterWidth : {7}) { + for (size_t stride : {2}) { + for (size_t padding : {1}) { + for (size_t dilation : {1, 3}) { + size_t filterSizeH = (filterHeight - 1) * dilation + 1; + size_t filterSizeW = (filterWidth - 1) * dilation + 1; + if (inputHeight + 2 * padding < filterSizeH || + inputWidth + 2 * padding < filterSizeW) + break; + if (padding >= filterSizeH || padding >= filterSizeW) break; + size_t outputHeight = + (inputHeight - filterSizeH + 2 * padding) / stride + 1; + size_t outputWidth = + (inputWidth - filterSizeW + 2 * padding) / stride + 1; + + TensorShape imShape = + TensorShape({channels, inputHeight, inputWidth}); + TensorShape colShape1 = TensorShape({channels, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + size_t height = channels * filterHeight * filterWidth; + size_t width = outputHeight * outputWidth; + VectorPtr input1 = + Vector::create(imShape.getElements(), false); + VectorPtr input2 = + Vector::create(imShape.getElements(), false); + MatrixPtr output1 = + Matrix::create(height, width, false, false); + MatrixPtr output2 = + Matrix::create(height, width, false, false); + input1->uniform(0.001, 1); + input2->copyFrom(*input1); + + Im2ColFunctor im2Col1; + Im2ColMobileFunctor im2Col2; + im2Col1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation); + im2Col2(input2->getData(), + imShape, + output2->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation, + 0, + height, + 0, + width); + + autotest::TensorCheckEqual(*output1, *output2); + } + } + } + } + } + } + } + } +} + +TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor(); } + } // namespace paddle diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 3e686b1c415e61a24e0f6729555e672721cf806f..bfcc70b31defd378fba0e505b1522471ca285ac4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -53,7 +53,6 @@ function(op_library TARGET) if (${op_library_DEPS_len} GREATER 0) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) endif() - if (WITH_GPU) nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) @@ -187,6 +186,36 @@ endfunction() add_subdirectory(math) add_subdirectory(nccl) +set(DEPS_OPS + cond_op + cross_entropy_op + recurrent_op + softmax_with_cross_entropy_op + softmax_op + sequence_softmax_op + sum_op + pool_op + maxout_op + unpool_op + pool_with_index_op + conv_op + conv_transpose_op + nccl_op + sequence_conv_op + sequence_pool_op + lod_rank_table_op + lod_tensor_to_array_op + array_to_lod_tensor_op + max_sequence_len_op + lstm_op + gru_op + adagrad_op + sgd_op + save_op + load_op + send_op + recv_op + detection_output_op) if(WITH_GPU) op_library(nccl_op DEPS nccl_common) else() @@ -210,6 +239,7 @@ op_library(cond_op DEPS framework_proto tensor net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_op DEPS softmax) +op_library(detection_output_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) @@ -229,6 +259,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(conv_transpose_op DEPS vol2col) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op DEPS executor) +op_library(cos_sim_op DEPS cos_sim_functor) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc index 052c793a01907abdc7784d1290f43543ae81bdb1..c83318a272302a474c37ce86365201acf56b9cad 100644 --- a/paddle/operators/adagrad_op.cc +++ b/paddle/operators/adagrad_op.cc @@ -105,48 +105,18 @@ struct SparseAdagradFunctor { const framework::Tensor& learning_rate, T epsilon, framework::Tensor* moment, framework::Tensor* param) { // 1. g_m.rows = set(g.rows) - auto grad_rows = grad.rows(); - std::set row_set(grad_rows.begin(), grad_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); - auto grad_width = grad.value().dims()[1]; - std::unique_ptr grad_merge{ - new framework::SelectedRows()}; - grad_merge->set_rows(merge_rows); - grad_merge->set_height(grad.height()); - grad_merge->mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), grad_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, grad_merge->mutable_value(), 0.0); - - auto* grad_merge_data = grad_merge->mutable_value()->data(); - auto* grad_data = grad.value().data(); - - for (size_t i = 0; i < grad_rows.size(); i++) { - size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]); - for (int64_t j = 0; j < grad_width; j++) { - grad_merge_data[grad_merge_i * grad_width + j] += - grad_data[i * grad_width + j]; - } - } + math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto& merge_rows = grad_merge.rows(); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); // 2. m += g_m * g_m - std::unique_ptr grad_square{ - new framework::SelectedRows()}; - grad_square->set_rows(grad_merge->rows()); - grad_square->set_height(grad_merge->height()); - grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), - context.GetPlace()); - auto gs = - framework::EigenVector::Flatten(*(grad_square->mutable_value())); - auto gm = framework::EigenVector::Flatten(grad_merge->value()); - gs.device(*context.eigen_device()) = gm * gm; + math::scatter::Mul sqare_func; + auto grad_square = sqare_func(context, grad_merge, grad_merge); math::SelectedRowsAddToTensor functor; - functor(context, *grad_square, moment); + functor(context, grad_square, moment); // 3. update parameter auto* lr = learning_rate.data(); diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu index 75bc7affd6c78beb783e01682b4538f2c259df26..4e579387924a5b0499f29609bc6b1322030a3c0d 100644 --- a/paddle/operators/adagrad_op.cu +++ b/paddle/operators/adagrad_op.cu @@ -78,62 +78,30 @@ struct SparseAdagradFunctor { const framework::Tensor& learning_rate, T epsilon, framework::Tensor* moment, framework::Tensor* param) { // 1. g_m.rows = set(g.rows) - auto grad_rows = grad.rows(); - std::set row_set(grad_rows.begin(), grad_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); - auto grad_width = grad.value().dims()[1]; - std::unique_ptr grad_merge{ - new framework::SelectedRows()}; - grad_merge->set_rows(merge_rows); - grad_merge->set_height(grad.height()); - grad_merge->mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), grad_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, grad_merge->mutable_value(), 0.0); - - auto* grad_merge_data = grad_merge->mutable_value()->data(); - auto* grad_data = grad.value().data(); - - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid1(1, grad_rows.size()); - - MergeGradKernel< - T, 256><<(context) - .stream()>>>(grad_data, grad.rows().data(), - grad_merge_data, grad_merge->rows().data(), - grad_merge->rows().size(), grad_width); - + math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); + auto& merge_rows = grad_merge.rows(); // 2. m += g_m * g_m - std::unique_ptr grad_square{ - new framework::SelectedRows()}; - grad_square->set_rows(grad_merge->rows()); - grad_square->set_height(grad_merge->height()); - grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), - context.GetPlace()); - auto gs = - framework::EigenVector::Flatten(*(grad_square->mutable_value())); - auto gm = framework::EigenVector::Flatten(grad_merge->value()); - gs.device(*context.eigen_device()) = gm * gm; + math::scatter::Mul sqare_func; + auto grad_square = sqare_func(context, grad_merge, grad_merge); math::SelectedRowsAddToTensor functor; - functor(context, *grad_square, moment); + functor(context, grad_square, moment); // 3. update parameter auto* lr = learning_rate.data(); auto* param_data = param->data(); auto* moment_data = moment->data(); + const int block_size = 256; + dim3 threads(block_size, 1); dim3 grid2(1, merge_rows.size()); SparseAdagradFunctorKernel< T, 256><<(context) - .stream()>>>(grad_merge_data, grad_merge->rows().data(), + .stream()>>>(grad_merge_data, grad_merge.rows().data(), lr, param_data, moment_data, grad_width, epsilon); } diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index c4e2c8bb88ec9c74bd782570c10fb217178c8e48..9cc34bdded780e61e8700eb4fa4a295c84fb48bc 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -16,11 +16,14 @@ limitations under the License. */ #include // for sqrt in CPU and CUDA #include "paddle/framework/op_registry.h" #include "paddle/operators/detail/safe_ref.h" +#include "paddle/operators/math/selected_rows_functor.h" #include "paddle/platform/for_range.h" namespace paddle { namespace operators { +namespace scatter = paddle::operators::math::scatter; + template struct AdamFunctor { T beta1_; @@ -79,6 +82,69 @@ struct AdamFunctor { } }; +template +struct SparseAdamFunctor { + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* lr_; + const T* grad_; + const T* param_; + T* param_out_; + + const int64_t* rows_; + int64_t row_numel_; + + SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, + const T* beta2_pow, const T* mom1, T* mom1_out, + const T* mom2, T* mom2_out, const T* lr, const T* grad, + const T* param, T* param_out, const int64_t* rows, + int64_t row_numel) + : beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + lr_(lr), + grad_(grad), + param_(param), + param_out_(param_out), + rows_(rows), + row_numel_(row_numel) {} + + inline HOSTDEVICE void operator()(size_t i) const { + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + for (int64_t j = 0; j < row_numel_; ++j) { + T g = grad_[i * row_numel_ + j]; + T mom1 = moment1_[rows_[i] * row_numel_ + j]; + T mom2 = moment2_[rows_[i] * row_numel_ + j]; + T lr = *lr_; + T p = param_[rows_[i] * row_numel_ + j]; + + lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + + moment1_out_[rows_[i] * row_numel_ + j] = mom1; + moment2_out_[rows_[i] * row_numel_ + j] = mom2; + param_out_[rows_[i] * row_numel_ + j] = p; + } // for col id + } +}; + template class AdamOpKernel : public framework::OpKernel { public: @@ -90,7 +156,8 @@ class AdamOpKernel : public framework::OpKernel { T beta2 = static_cast(ctx.Attr("beta2")); T epsilon = static_cast(ctx.Attr("epsilon")); auto& param = Ref(ctx.Input("Param"), "Must set Param"); - auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + // auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + auto* grad_var = ctx.InputVar("Grad"); auto& mom1 = Ref(ctx.Input("Moment1"), "Must set Moment1"); auto& mom2 = Ref(ctx.Input("Moment2"), "Must set Moment2"); auto& lr = @@ -108,18 +175,48 @@ class AdamOpKernel : public framework::OpKernel { auto& mom2_out = Ref(ctx.Output("Moment2Out"), "Must set Moment1Out"); - AdamFunctor functor(beta1, beta2, epsilon, beta1_pow.template data(), - beta2_pow.template data(), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - lr.template data(), grad.template data(), - param.template data(), - param_out.template mutable_data(ctx.GetPlace())); - platform::ForRange for_range( - static_cast(ctx.device_context()), param.numel()); - for_range(functor); + if (grad_var->IsType()) { + auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + AdamFunctor functor( + beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + lr.template data(), grad.template data(), + param.template data(), + param_out.template mutable_data(ctx.GetPlace())); + platform::ForRange for_range( + static_cast(ctx.device_context()), + param.numel()); + for_range(functor); + } else if (grad_var->IsType()) { + auto& grad = + Ref(ctx.Input("Grad"), "Must set Grad"); + // merge duplicated rows if any. + scatter::MergeAdd merge_func; + auto grad_merge = + merge_func(ctx.template device_context(), grad); + auto& grad_tensor = grad_merge.value(); + const T* grad_data = grad_tensor.template data(); + auto* rows = grad_merge.rows().data(); + auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); + + SparseAdamFunctor functor( + beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + lr.template data(), grad_data, param.template data(), + param_out.template mutable_data(ctx.GetPlace()), rows, row_numel); + platform::ForRange for_range( + static_cast(ctx.device_context()), + grad_merge.rows().size()); + for_range(functor); + } else { + PADDLE_THROW("Variable type not supported by adam_op"); + } } }; diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index ab52a41b539236f1691ce8bc02d31e336ee4ccbb..e65a5dce52c3c51d3d6bee1684c1e97230203d38 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -31,8 +31,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); std::vector dilations = ctx->Attrs().Get>("dilations"); - int input_channels = in_dims[1]; - int output_channels = filter_dims[0]; PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, "Conv intput should be 4-D or 5-D tensor."); @@ -45,9 +43,13 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ( paddings.size(), strides.size(), "Conv paddings dimension and Conv strides dimension should be the same."); + + int input_channels = in_dims[1]; PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, "The number of input channels should be equal to filter " "channels * groups."); + + int output_channels = filter_dims[0]; PADDLE_ENFORCE_EQ( output_channels % groups, 0, "The number of output channels should be divided by groups."); diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index e2b6282c0913e8ad16f8e3f6c3054f9567822d15..eadcca55f9bfc3e59f329df8ff419ad4c5a29007 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -13,19 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/cos_sim_functor.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; -template -using EigenVector = framework::EigenVector; template class CosSimKernel : public framework::OpKernel { @@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel { out_x_norm->mutable_data(context.GetPlace()); out_y_norm->mutable_data(context.GetPlace()); - // convert Tensor to Eigen Tensor int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenVector::Flatten(*out_z); - auto x_norm = EigenVector::Flatten(*out_x_norm); - auto y_norm = EigenVector::Flatten(*out_y_norm); - // compute - auto& place = - *context.template device_context().eigen_device(); - auto row_along = Eigen::array({{1}}); - x_norm.device(place) = x.square().sum(row_along).sqrt(); - y_norm.device(place) = y.square().sum(row_along).sqrt(); + int cols = framework::product(in_x->dims()) / rows_x; + if (rows_x == rows_y) { - auto xy = (x * y).sum(Eigen::array({{1}})); - z.device(place) = xy / x_norm / y_norm; + math::CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + platform::ForRange for_range( + static_cast(context.device_context()), rows_x); + for_range(functor); } else { - Eigen::DSizes bcast(rows_x, 1); - auto xy = (x * y.broadcast(bcast)).sum(row_along); - z.device(place) = xy / x_norm / y_norm.broadcast(bcast); + math::CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + platform::ForRange for_range( + static_cast(context.device_context()), rows_x); + for_range(functor); } } }; @@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel { auto* out_grad_y = context.Output(framework::GradVarName("Y")); auto* in_grad_z = context.Input(framework::GradVarName("Out")); - // convert Tensor to Eigen Tensor - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenMatrix::Reshape(*in_z, 1); - auto x_norm = EigenMatrix::Reshape(*in_x_norm, 1); - auto y_norm = EigenMatrix::Reshape(*in_y_norm, 1); - auto dz = EigenMatrix::Reshape(*in_grad_z, 1); - // compute gradident int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; int cols = framework::product(in_x->dims()) / rows_x; - Eigen::DSizes bcast_cols(1, cols); - auto z_bcast = z.broadcast(bcast_cols); - auto dz_bcast = dz.broadcast(bcast_cols); - auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols); - auto& place = - *context.template device_context().eigen_device(); + if (rows_x == rows_y) { - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols); - auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols); - // compute dx if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; + math::CosSimGradFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } - // compute dy if (out_grad_y) { - out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenMatrix::Reshape(*out_grad_y, 1); - auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast; - dy.device(place) = dz_bcast * grad; + math::CosSimGradFunctor functor( + in_y_norm->data(), in_x_norm->data(), in_y->data(), + in_x->data(), in_z->data(), in_grad_z->data(), + out_grad_y->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } } else { - Eigen::DSizes bcast_rows(rows_x, 1); - Eigen::DSizes bcast_rows_cols(rows_x, cols); - auto y_bcast = y.broadcast(bcast_rows); - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols); - auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows)) - .eval() - .broadcast(bcast_cols); - // compute dx if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; + math::CosSimDxFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } - // compute dy if (out_grad_y) { out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenVector::Flatten(*out_grad_y); - auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast; - dy.device(place) = (dz_bcast * grad).sum(Eigen::array({{0}})); + math::SetConstant set_zero; + auto& dev_ctx = context.template device_context(); + set_zero(dev_ctx, out_grad_y, static_cast(0)); + + math::CosSimDyFunctor functor; + functor(dev_ctx, in_x_norm->data(), in_y_norm->data(), + in_x->data(), in_y->data(), in_z->data(), + in_grad_z->data(), static_cast(rows_x), + static_cast(cols), out_grad_y->data()); } } } diff --git a/paddle/operators/detection_output_op.cc b/paddle/operators/detection_output_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea44cd32678d7e8a5836c1886cf9c1b4961970aa --- /dev/null +++ b/paddle/operators/detection_output_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detection_output_op.h" +namespace paddle { +namespace operators { + +class DetectionOutputOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DetectionOutputOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Loc", + "(Tensor) The input tensor of detection_output operator." + "The input predict locations" + "The format of input tensor is kNCHW. Where K is priorbox point " + "numbers," + "N is How many boxes are there on each point, " + "C is 4, H and W both are 1."); + AddInput("Conf", + "(Tensor) The input tensor of detection_output operator." + "The input priorbox confidence." + "The format of input tensor is kNCHW. Where K is priorbox point " + "numbers," + "N is How many boxes are there on each point, " + "C is the number of classes, H and W both are 1."); + AddInput("PriorBox", + "(Tensor) The input tensor of detection_output operator." + "The format of input tensor is the position and variance " + "of the boxes"); + AddOutput("Out", + "(Tensor) The output tensor of detection_output operator."); + AddAttr("background_label_id", "(int), The background class index."); + AddAttr("num_classes", "(int), The number of the classification."); + AddAttr("nms_threshold", + "(float), The Non-maximum suppression threshold."); + AddAttr("confidence_threshold", + "(float), The classification confidence threshold."); + AddAttr("top_k", "(int), The bbox number kept of the layer’s output."); + AddAttr("nms_top_k", + "(int), The bbox number kept of the NMS’s output."); + AddComment(R"DOC( + detection output for SSD(single shot multibox detector) + Apply the NMS to the output of network and compute the predict + bounding box location. The output’s shape of this layer could + be zero if there is no valid bounding box. + )DOC"); + } +}; + +class DetectionOutputOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Loc"), + "Input(X) of DetectionOutputOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Conf"), + "Input(X) of DetectionOutputOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("PriorBox"), + "Input(X) of DetectionOutputOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of DetectionOutputOp should not be null."); + std::vector output_shape({1, 7}); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(detection_output, ops::DetectionOutputOp, + ops::DetectionOutputOpMaker); +REGISTER_OP_CPU_KERNEL( + detection_output, + ops::DetectionOutputKernel, + ops::DetectionOutputKernel); diff --git a/paddle/operators/detection_output_op.cu.cc b/paddle/operators/detection_output_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a6560e0492c559afd06e5152c34fab545b7ce61 --- /dev/null +++ b/paddle/operators/detection_output_op.cu.cc @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/detection_output_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + detection_output, + ops::DetectionOutputKernel, + ops::DetectionOutputKernel); diff --git a/paddle/operators/detection_output_op.h b/paddle/operators/detection_output_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f8abd5b6406f05747b87fcfd464baeb705a7f7f2 --- /dev/null +++ b/paddle/operators/detection_output_op.h @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" +#include "paddle/operators/math/detection_util.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/softmax.h" +#include "paddle/operators/strided_memcpy.h" +namespace paddle { +namespace operators { +template +inline void transpose_fun(const framework::ExecutionContext& context, + const framework::Tensor& src, + framework::Tensor* dst) { + int input_nums = src.dims()[0]; + int offset = 0; + for (int j = 0; j < input_nums; ++j) { + framework::Tensor in_p_tensor = src.Slice(j, j + 1); + std::vector shape_vec( + {in_p_tensor.dims()[0], in_p_tensor.dims()[1], in_p_tensor.dims()[3], + in_p_tensor.dims()[4], in_p_tensor.dims()[2]}); + framework::DDim shape(framework::make_ddim(shape_vec)); + framework::Tensor in_p_tensor_transpose; + in_p_tensor_transpose.mutable_data(shape, context.GetPlace()); + std::vector shape_axis({0, 1, 3, 4, 2}); + math::Transpose trans5; + trans5(context.template device_context(), in_p_tensor, + &in_p_tensor_transpose, shape_axis); + auto dst_stride = framework::stride(dst->dims()); + auto src_stride = framework::stride(in_p_tensor_transpose.dims()); + StridedMemcpy(context.device_context(), in_p_tensor_transpose.data(), + src_stride, in_p_tensor_transpose.dims(), dst_stride, + dst->data() + offset); + offset += in_p_tensor_transpose.dims()[4] * src_stride[4]; + } +} +template +class DetectionOutputKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_loc = context.Input("Loc"); + const framework::Tensor* in_conf = context.Input("Conf"); + const framework::Tensor* in_priorbox = + context.Input("PriorBox"); + auto* out = context.Output("Out"); + int num_classes = context.template Attr("num_classes"); + int top_k = context.template Attr("top_k"); + int nms_top_k = context.template Attr("nms_top_k"); + int background_label_id = context.template Attr("background_label_id"); + float nms_threshold = context.template Attr("nms_threshold"); + float confidence_threshold = + context.template Attr("confidence_threshold"); + size_t batch_size = in_conf->dims()[1]; + int conf_sum_size = in_conf->numel(); + // for softmax + std::vector conf_shape_softmax_vec( + {conf_sum_size / num_classes, num_classes}); + framework::DDim conf_shape_softmax( + framework::make_ddim(conf_shape_softmax_vec)); + // for knchw => nhwc + std::vector loc_shape_vec({1, in_loc->dims()[1], in_loc->dims()[3], + in_loc->dims()[4], + in_loc->dims()[2] * in_loc->dims()[0]}); + std::vector conf_shape_vec( + {1, in_conf->dims()[1], in_conf->dims()[3], in_conf->dims()[4], + in_conf->dims()[2] * in_conf->dims()[0]}); + framework::DDim loc_shape(framework::make_ddim(loc_shape_vec)); + framework::DDim conf_shape(framework::make_ddim(conf_shape_vec)); + framework::Tensor loc_tensor; + framework::Tensor conf_tensor; + loc_tensor.mutable_data(loc_shape, context.GetPlace()); + conf_tensor.mutable_data(conf_shape, context.GetPlace()); + // for cpu + framework::Tensor loc_cpu; + framework::Tensor conf_cpu; + framework::Tensor priorbox_cpu; + const T* priorbox_data = in_priorbox->data(); + transpose_fun(context, *in_loc, &loc_tensor); + transpose_fun(context, *in_conf, &conf_tensor); + conf_tensor.Resize(conf_shape_softmax); + math::SoftmaxFunctor()( + context.template device_context(), &conf_tensor, + &conf_tensor); + T* loc_data = loc_tensor.data(); + T* conf_data = conf_tensor.data(); + if (platform::is_gpu_place(context.GetPlace())) { + loc_cpu.mutable_data(loc_tensor.dims(), platform::CPUPlace()); + framework::CopyFrom(loc_tensor, platform::CPUPlace(), + context.device_context(), &loc_cpu); + loc_data = loc_cpu.data(); + conf_cpu.mutable_data(conf_tensor.dims(), platform::CPUPlace()); + framework::CopyFrom(conf_tensor, platform::CPUPlace(), + context.device_context(), &conf_cpu); + conf_data = conf_cpu.data(); + priorbox_cpu.mutable_data(in_priorbox->dims(), platform::CPUPlace()); + framework::CopyFrom(*in_priorbox, platform::CPUPlace(), + context.device_context(), &priorbox_cpu); + priorbox_data = priorbox_cpu.data(); + } + // get decode bboxes + size_t num_priors = in_priorbox->numel() / 8; + std::vector>> all_decoded_bboxes; + for (size_t n = 0; n < batch_size; ++n) { + std::vector> decoded_bboxes; + for (size_t i = 0; i < num_priors; ++i) { + size_t prior_offset = i * 8; + size_t loc_pred_offset = n * num_priors * 4 + i * 4; + std::vector> prior_bbox_vec; + math::GetBBoxFromPriorData(priorbox_data + prior_offset, 1, + prior_bbox_vec); + std::vector> prior_bbox_var; + math::GetBBoxVarFromPriorData(priorbox_data + prior_offset, 1, + prior_bbox_var); + std::vector loc_pred_data; + for (size_t j = 0; j < 4; ++j) + loc_pred_data.push_back(*(loc_data + loc_pred_offset + j)); + math::BBox bbox = math::DecodeBBoxWithVar( + prior_bbox_vec[0], prior_bbox_var[0], loc_pred_data); + decoded_bboxes.push_back(bbox); + } + all_decoded_bboxes.push_back(decoded_bboxes); + } + std::vector>> all_indices; + int num_kept = math::GetDetectionIndices( + conf_data, num_priors, num_classes, background_label_id, batch_size, + confidence_threshold, nms_top_k, nms_threshold, top_k, + all_decoded_bboxes, &all_indices); + + if (num_kept <= 0) { + std::vector out_shape_vec({0, 0}); + framework::DDim out_shape(framework::make_ddim(out_shape_vec)); + out->Resize(out_shape); + return; + } + std::vector out_shape_vec({num_kept, 7}); + framework::DDim out_shape(framework::make_ddim(out_shape_vec)); + out->mutable_data(out_shape, context.GetPlace()); + framework::Tensor out_cpu; + T* out_data = out->data(); + if (platform::is_gpu_place(context.GetPlace())) { + out_cpu.mutable_data(out->dims(), platform::CPUPlace()); + out_data = out_cpu.data(); + } + math::GetDetectionOutput(conf_data, num_kept, num_priors, num_classes, + batch_size, all_indices, all_decoded_bboxes, + out_data); + if (platform::is_gpu_place(context.GetPlace())) { + framework::CopyFrom(out_cpu, platform::CUDAPlace(), + context.device_context(), out); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index c6228864d7ec042ff99e4521d1d707ba091e8ed5..b1957fb9ce6add8628cb206abf2c569d3f615c85 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/operators/math/gru_compute.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/sequence2batch.h" @@ -70,7 +71,7 @@ class GRUKernel : public framework::OpKernel { } int frame_size = hidden_dims[1]; - math::hl_gru_value gru_value; + math::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); @@ -89,6 +90,10 @@ class GRUKernel : public framework::OpKernel { } auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -101,9 +106,8 @@ class GRUKernel : public framework::OpKernel { gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, - math::ActiveType(context.Attr("activation")), - math::ActiveType(context.Attr("gate_activation"))); + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); gru_value.prev_out_value = gru_value.output_value; } @@ -170,12 +174,12 @@ class GRUGradKernel : public framework::OpKernel { batch_hidden_grad.set_lod(batch_hidden->lod()); to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); - math::hl_gru_value gru_value; + math::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); - math::hl_gru_grad gru_grad; + math::GRUMetaGrad gru_grad; if (weight_grad) { gru_grad.gate_weight_grad = weight_grad->mutable_data(context.GetPlace()); @@ -189,6 +193,10 @@ class GRUGradKernel : public framework::OpKernel { auto batch_starts = batch_hidden_grad.lod()[0]; size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -219,9 +227,8 @@ class GRUGradKernel : public framework::OpKernel { } math::GRUUnitGradFunctor::compute( - dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, - math::ActiveType(context.Attr("activation")), - math::ActiveType(context.Attr("gate_activation"))); + dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, + active_gate); } if (input_grad) { input_grad->mutable_data(context.GetPlace()); diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index 65f021d91931541b712bd46aebc06e68144b2af0..08b972a233aab8596a5ce7f74ea903df3b8ef0f2 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -38,7 +38,7 @@ class LoadOp : public framework::OperatorBase { out_var_name); auto *tensor = out_var->GetMutable(); - framework::DeserializeFromStream(fin, tensor); + DeserializeFromStream(fin, tensor); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index bf47879f772a3013bd7ce78c6f8a6aefe65298f9..7ebcfb9ab9f30e3b0f13d3646a59d008335b232d 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -9,13 +9,14 @@ if(WITH_GPU) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) - nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) + nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) - nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) + nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -23,13 +24,14 @@ else() cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) - cc_library(vol2col SRCS vol2col.cc DEPS device_context) + cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor) cc_library(context_project SRCS context_project.cc DEPS device_context math_function) - cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) + cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cos_sim_functor.cc b/paddle/operators/math/cos_sim_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..6af9f0fcd967b4e8e9e338c155d5a8ee2866fbfa --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/cos_sim_functor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CosSimDyFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm, + const T* y_norm, const T* x, const T* y, const T* z, + const T* dz, const size_t rows, const size_t cols, + T* dy) const { + for (size_t row_id = 0; row_id < rows; ++row_id) { + auto xy_norm_prod = x_norm[row_id] * y_norm[0]; + auto dz_data = dz[row_id]; + auto z_data = z[row_id]; + auto* x_data = x + cols * row_id; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + auto y_norm_square = y_norm[0] * y_norm[0]; + auto reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols; ++i) { + dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod - + z_data * y[i] * reciprocal_y_norm_square); + } + } + } +}; + +template struct CosSimDyFunctor; +template struct CosSimDyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cos_sim_functor.cu b/paddle/operators/math/cos_sim_functor.cu new file mode 100644 index 0000000000000000000000000000000000000000..6eb0a4ea4c5b86f84c93b97615255adf55e9e042 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cu @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/cos_sim_functor.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x, + const T* y, const T* z, const T* dz, + const size_t rows, const size_t cols, T* dy) { + int grid_size = blockDim.x * gridDim.x; + T y_norm_data = y_norm[0]; + for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows; + row_id += grid_size) { + T xy_norm_prod = x_norm[row_id] * y_norm_data; + T dz_data = dz[row_id]; + T z_data = z[row_id]; + const T* x_data = x + cols * row_id; + T reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + T y_norm_square = y_norm_data * y_norm_data; + T reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols; ++i) { + T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod - + z_data * y[i] * reciprocal_y_norm_square); + platform::CudaAtomicAdd(dy + i, dy_data); + } + } +} + +template +struct CosSimDyFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm, + const T* y_norm, const T* x, const T* y, const T* z, + const T* dz, const size_t rows, const size_t cols, + T* dy) const { + const int block_size = 512; + dim3 threads(block_size, 1); + dim3 grid(1, (rows + block_size - 1) / block_size); + CosSimDyKernel<<>>( + x_norm, y_norm, x, y, z, dz, rows, cols, dy); + } +}; + +template struct CosSimDyFunctor; +template struct CosSimDyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cos_sim_functor.h b/paddle/operators/math/cos_sim_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..aae8ab5b7a937c016e8a45e34b22aba7a1df3066 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.h @@ -0,0 +1,166 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CosSimFunctor { + CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto* x = x_ + cols_ * row_id; + T xx = 0, xy = 0, yy = 0; + if (same_row) { + auto* y = y_ + cols_ * row_id; + T tep_x, tep_y; + for (size_t i = 0; i < cols_; ++i) { + tep_x = x[i]; + tep_y = y[i]; + xx += tep_x * tep_x; + yy += tep_y * tep_y; + xy += tep_x * tep_y; + } + xx = sqrt(xx); + yy = sqrt(yy); + y_norm_[row_id] = yy; + x_norm_[row_id] = xx; + z_[row_id] = xy / (xx * yy); + } else { // This can be wrote in a better way. + T tep_x, tep_y; + for (size_t i = 0; i < cols_; ++i) { + tep_x = x[i]; + tep_y = y_[i]; + xx += tep_x * tep_x; + yy += tep_y * tep_y; + xy += tep_x * tep_y; + } + xx = sqrt(xx); + yy = sqrt(yy); + if (row_id == 0) y_norm_[0] = yy; + x_norm_[row_id] = xx; + z_[row_id] = xy / (xx * yy); + } + } + + T* x_norm_; + T* y_norm_; + const T* x_; + const T* y_; + T* z_; + const size_t cols_; +}; + +template +struct CosSimGradFunctor { + CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; + auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id]; + auto dz = dz_[row_id]; + auto z = z_[row_id]; + + auto* dx = dx_ + cols_ * row_id; + auto* x = x_ + cols_ * row_id; + auto* y = y_ + cols_ * row_id; + + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto reciprocal_x_norm_square = 1 / x_norm_square; + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDxFunctor { + CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto xy_norm_prod = x_norm_[row_id] * y_norm_[0]; + auto dz = dz_[row_id]; + auto z = z_[row_id]; + auto* x = x_ + cols_ * row_id; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; + auto* dx = dx_ + cols_ * row_id; + auto reciprocal_x_norm_square = 1 / x_norm_square; + + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); + } + } + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDyFunctor { + void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm, + const T* x, const T* y, const T* z, const T* dz, + const size_t rows, const size_t cols, T* dy) const; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/gru_cpu_kernel.h b/paddle/operators/math/detail/gru_cpu_kernel.h index 4c67dec9cbeb48f400f79f5ed7ba3c939fa2540c..a61b232f4275d93cae1d9a71d49a779216c3555b 100644 --- a/paddle/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/operators/math/detail/gru_cpu_kernel.h @@ -28,7 +28,7 @@ template void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { T r_value_update_gate; T r_value_reset_gate; T r_value_reset_output; @@ -56,7 +56,7 @@ template void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -83,7 +83,7 @@ template void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { #ifdef __AVX__ __m256 r_value_update_gate; __m256 r_value_reset_gate; @@ -113,7 +113,7 @@ template void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { #ifdef __AVX__ __m256 r_value_update_gate; __m256 r_value_frame_state; @@ -140,9 +140,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, template inline void forward_reset_output(OpResetOutput op_reset_output, - hl_gru_value value, int frame_size, - int batch_size, - activation_mode_t active_gate) { + GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_gate) { for (int b = 0; b < batch_size; b++) { if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_reset_output( @@ -164,9 +163,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output, template inline void forward_final_output(OpFinalOutput op_final_output, - hl_gru_value value, int frame_size, - int batch_size, - activation_mode_t active_node) { + GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { for (int b = 0; b < batch_size; b++) { if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(op_final_output, value.gate_value, @@ -191,7 +189,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { T r_update_gate_value; T r_update_gate_grad; T r_frame_state_value; @@ -232,7 +230,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { T r_update_gate_value; T r_update_gate_grad; T r_reset_gate_value; @@ -277,7 +275,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -320,7 +318,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -364,9 +362,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, template inline void backward_state_grad(OpStateGrad op_state_grad, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_node) { + ActivationType active_node) { for (int b = 0; b < batch_size; b++) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_state_grad( @@ -393,9 +391,9 @@ inline void backward_state_grad(OpStateGrad op_state_grad, template inline void backward_reset_grad(OpResetGrad op_reset_grad, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_gate) { + ActivationType active_gate) { for (int b = 0; b < batch_size; b++) { if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_reset_grad( diff --git a/paddle/operators/math/detail/gru_gpu_kernel.h b/paddle/operators/math/detail/gru_gpu_kernel.h index d2edcb7f258b387530799b967fc0fff61acc5b83..1783d46096858c874b27ce75760342082835b180 100644 --- a/paddle/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/operators/math/detail/gru_gpu_kernel.h @@ -19,8 +19,6 @@ limitations under the License. */ #include "paddle/platform/cuda_helper.h" #include "paddle/platform/device_context.h" -#include - namespace paddle { namespace operators { namespace math { @@ -35,7 +33,7 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size, int batch_size, - activation_mode_t active_gate) { + ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; @@ -74,7 +72,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, int batch_size, - activation_mode_t active_node) { + ActivationType active_node) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -111,7 +109,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, int batch_size, - activation_mode_t active_node) { + ActivationType active_node) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -159,7 +157,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, int batch_size, - activation_mode_t active_gate) { + ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; diff --git a/paddle/operators/math/detail/gru_kernel.h b/paddle/operators/math/detail/gru_kernel.h index acd84be01db9ddaf06d165d8be353b253f324dd2..4d8245cb5d03b33edbda5d8350be02b4fa87ab95 100644 --- a/paddle/operators/math/detail/gru_kernel.h +++ b/paddle/operators/math/detail/gru_kernel.h @@ -30,7 +30,7 @@ class gru_resetOutput { public: HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate, T &prev_out, T &value_reset_output, - activation_mode_t act_gate) { + ActivationType act_gate) { value_update_gate = activation(value_update_gate, act_gate); value_reset_gate = activation(value_reset_gate, act_gate); value_reset_output = prev_out * value_reset_gate; @@ -43,7 +43,7 @@ class gru_resetOutput { HOSTDEVICE void operator()(__m256 &value_update_gate, __m256 &value_reset_gate, __m256 &prev_out, __m256 &value_reset_output, - activation_mode_t act_gate) { + ActivationType act_gate) { value_update_gate = activation(value_update_gate, act_gate); value_reset_gate = activation(value_reset_gate, act_gate); value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate); @@ -57,7 +57,7 @@ class gru_finalOutput { public: HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state, T &prev_out, T &value_output, - activation_mode_t act_input) { + ActivationType act_input) { value_frame_state = activation(value_frame_state, act_input); value_output = prev_out - (value_update_gate * prev_out) + (value_update_gate * value_frame_state); @@ -69,8 +69,7 @@ class gru_finalOutput { static const bool avx = true; HOSTDEVICE void operator()(__m256 &value_update_gate, __m256 &value_frame_state, __m256 &prev_out, - __m256 &value_output, - activation_mode_t act_input) { + __m256 &value_output, ActivationType act_input) { value_frame_state = activation(value_frame_state, act_input); value_output = _mm256_add_ps( _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)), @@ -89,7 +88,7 @@ class gru_stateGrad { HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, T &value_frame_state, T &grad_frame_state, T &value_prev_out, T &grad_prev_out, - T &grad_output, activation_mode_t act_input) { + T &grad_output, ActivationType act_input) { grad_update_gate = (grad_output * value_frame_state); grad_update_gate -= (grad_output * value_prev_out); grad_prev_out -= (grad_output * value_update_gate); @@ -107,7 +106,7 @@ class gru_stateGrad { __m256 &value_frame_state, __m256 &grad_frame_state, __m256 &value_prev_out, __m256 &grad_prev_out, __m256 &grad_output, - activation_mode_t act_input) { + ActivationType act_input) { grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state); grad_update_gate = _mm256_sub_ps( grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out)); @@ -128,7 +127,7 @@ class gru_resetGrad { HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, T &value_reset_gate, T &grad_reset_gate, T &value_prev_out, T &grad_prev_out, - T &grad_reset_output, activation_mode_t act_gate) { + T &grad_reset_output, ActivationType act_gate) { grad_reset_gate = (grad_reset_output * value_prev_out); grad_prev_out += (grad_reset_output * value_reset_gate); grad_update_gate = @@ -144,7 +143,7 @@ class gru_resetGrad { __m256 &grad_update_gate, __m256 &value_reset_gate, __m256 &grad_reset_gate, __m256 &value_prev_out, __m256 &grad_prev_out, __m256 &grad_reset_output, - activation_mode_t act_gate) { + ActivationType act_gate) { grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out); grad_prev_out = _mm256_add_ps( grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate)); diff --git a/paddle/operators/math/detection_util.h b/paddle/operators/math/detection_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e3a3ef2badc37924d866ded8ee7a7338fbc4b2d2 --- /dev/null +++ b/paddle/operators/math/detection_util.h @@ -0,0 +1,300 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { +template +struct BBox { + BBox(T x_min, T y_min, T x_max, T y_max) + : x_min(x_min), + y_min(y_min), + x_max(x_max), + y_max(y_max), + is_difficult(false) {} + + BBox() {} + + T get_width() const { return x_max - x_min; } + + T get_height() const { return y_max - y_min; } + + T get_center_x() const { return (x_min + x_max) / 2; } + + T get_center_y() const { return (y_min + y_max) / 2; } + + T get_area() const { return get_width() * get_height(); } + + // coordinate of bounding box + T x_min; + T y_min; + T x_max; + T y_max; + // whether difficult object (e.g. object with heavy occlusion is difficult) + bool is_difficult; +}; +// KNCHW ==> NHWC +// template +template +void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes, + std::vector>& bbox_vec); +template +void GetBBoxVarFromPriorData(const T* prior_data, const size_t num, + std::vector>& var_vec); +template +BBox DecodeBBoxWithVar(BBox& prior_bbox, + const std::vector& prior_bbox_var, + const std::vector& loc_pred_data); +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2); +template +bool SortScorePairDescend(const std::pair>& pair1, + const std::pair>& pair2); +template +T jaccard_overlap(const BBox& bbox1, const BBox& bbox2); + +template +void ApplyNmsFast(const std::vector>& bboxes, const T* conf_score_data, + size_t class_idx, size_t top_k, T conf_threshold, + T nms_threshold, size_t num_priors, size_t num_classes, + std::vector* indices); +template +int GetDetectionIndices( + const T* conf_data, const size_t num_priors, const size_t num_classes, + const size_t background_label_id, const size_t batch_size, + const T conf_threshold, const size_t nms_top_k, const T nms_threshold, + const size_t top_k, + const std::vector>>& all_decoded_bboxes, + std::vector>>* all_detection_indices); +template +BBox ClipBBox(const BBox& bbox); +template +void GetDetectionOutput( + const T* conf_data, const size_t num_kept, const size_t num_priors, + const size_t num_classes, const size_t batch_size, + const std::vector>>& all_indices, + const std::vector>>& all_decoded_bboxes, T* out_data); +template +void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes, + std::vector>& bbox_vec) { + size_t out_offset = bbox_vec.size(); + bbox_vec.resize(bbox_vec.size() + num_bboxes); + for (size_t i = 0; i < num_bboxes; ++i) { + BBox bbox; + bbox.x_min = *(prior_data + i * 8); + bbox.y_min = *(prior_data + i * 8 + 1); + bbox.x_max = *(prior_data + i * 8 + 2); + bbox.y_max = *(prior_data + i * 8 + 3); + bbox_vec[out_offset + i] = bbox; + } +} +template +void GetBBoxVarFromPriorData(const T* prior_data, const size_t num, + std::vector>& var_vec) { + size_t out_offset = var_vec.size(); + var_vec.resize(var_vec.size() + num); + for (size_t i = 0; i < num; ++i) { + std::vector var; + var.push_back(*(prior_data + i * 8 + 4)); + var.push_back(*(prior_data + i * 8 + 5)); + var.push_back(*(prior_data + i * 8 + 6)); + var.push_back(*(prior_data + i * 8 + 7)); + var_vec[out_offset + i] = var; + } +} +template +BBox DecodeBBoxWithVar(BBox& prior_bbox, + const std::vector& prior_bbox_var, + const std::vector& loc_pred_data) { + T prior_bbox_width = prior_bbox.get_width(); + T prior_bbox_height = prior_bbox.get_height(); + T prior_bbox_center_x = prior_bbox.get_center_x(); + T prior_bbox_center_y = prior_bbox.get_center_y(); + + T decoded_bbox_center_x = + prior_bbox_var[0] * loc_pred_data[0] * prior_bbox_width + + prior_bbox_center_x; + T decoded_bbox_center_y = + prior_bbox_var[1] * loc_pred_data[1] * prior_bbox_height + + prior_bbox_center_y; + T decoded_bbox_width = + std::exp(prior_bbox_var[2] * loc_pred_data[2]) * prior_bbox_width; + T decoded_bbox_height = + std::exp(prior_bbox_var[3] * loc_pred_data[3]) * prior_bbox_height; + + BBox decoded_bbox; + decoded_bbox.x_min = decoded_bbox_center_x - decoded_bbox_width / 2; + decoded_bbox.y_min = decoded_bbox_center_y - decoded_bbox_height / 2; + decoded_bbox.x_max = decoded_bbox_center_x + decoded_bbox_width / 2; + decoded_bbox.y_max = decoded_bbox_center_y + decoded_bbox_height / 2; + + return decoded_bbox; +} +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} +template +T jaccard_overlap(const BBox& bbox1, const BBox& bbox2) { + if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min || + bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) { + return 0.0; + } else { + T inter_x_min = std::max(bbox1.x_min, bbox2.x_min); + T inter_y_min = std::max(bbox1.y_min, bbox2.y_min); + T interX_max = std::min(bbox1.x_max, bbox2.x_max); + T interY_max = std::min(bbox1.y_max, bbox2.y_max); + + T inter_width = interX_max - inter_x_min; + T inter_height = interY_max - inter_y_min; + T inter_area = inter_width * inter_height; + + T bbox_area1 = bbox1.get_area(); + T bbox_area2 = bbox2.get_area(); + + return inter_area / (bbox_area1 + bbox_area2 - inter_area); + } +} + +template +void ApplyNmsFast(const std::vector>& bboxes, const T* conf_score_data, + size_t class_idx, size_t top_k, T conf_threshold, + T nms_threshold, size_t num_priors, size_t num_classes, + std::vector* indices) { + std::vector> scores; + for (size_t i = 0; i < num_priors; ++i) { + size_t conf_offset = i * num_classes + class_idx; + if (conf_score_data[conf_offset] > conf_threshold) + scores.push_back(std::make_pair(conf_score_data[conf_offset], i)); + } + std::stable_sort(scores.begin(), scores.end(), + SortScorePairDescend); + if (top_k > 0 && top_k < scores.size()) scores.resize(top_k); + while (scores.size() > 0) { + const size_t idx = scores.front().second; + bool keep = true; + for (size_t i = 0; i < indices->size(); ++i) { + if (keep) { + const size_t saved_idx = (*indices)[i]; + T overlap = jaccard_overlap(bboxes[idx], bboxes[saved_idx]); + keep = overlap <= nms_threshold; + } else { + break; + } + } + if (keep) indices->push_back(idx); + scores.erase(scores.begin()); + } +} +template +int GetDetectionIndices( + const T* conf_data, const size_t num_priors, const size_t num_classes, + const size_t background_label_id, const size_t batch_size, + const T conf_threshold, const size_t nms_top_k, const T nms_threshold, + const size_t top_k, + const std::vector>>& all_decoded_bboxes, + std::vector>>* all_detection_indices) { + int total_keep_num = 0; + for (size_t n = 0; n < batch_size; ++n) { + const std::vector>& decoded_bboxes = all_decoded_bboxes[n]; + size_t num_detected = 0; + std::map> indices; + size_t conf_offset = n * num_priors * num_classes; + for (size_t c = 0; c < num_classes; ++c) { + if (c == background_label_id) continue; + ApplyNmsFast(decoded_bboxes, conf_data + conf_offset, c, nms_top_k, + conf_threshold, nms_threshold, num_priors, num_classes, + &(indices[c])); + num_detected += indices[c].size(); + } + if (top_k > 0 && num_detected > top_k) { + // std::vector> score_index_pairs; + std::vector>> score_index_pairs; + for (size_t c = 0; c < num_classes; ++c) { + const std::vector& label_indices = indices[c]; + for (size_t i = 0; i < label_indices.size(); ++i) { + size_t idx = label_indices[i]; + score_index_pairs.push_back( + std::make_pair((conf_data + conf_offset)[idx * num_classes + c], + std::make_pair(c, idx))); + } + } + std::sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(top_k); + std::map> new_indices; + for (size_t i = 0; i < score_index_pairs.size(); ++i) { + size_t label = score_index_pairs[i].second.first; + size_t idx = score_index_pairs[i].second.second; + new_indices[label].push_back(idx); + } + all_detection_indices->push_back(new_indices); + total_keep_num += top_k; + } else { + all_detection_indices->push_back(indices); + total_keep_num += num_detected; + } + } + return total_keep_num; +} +template +BBox ClipBBox(const BBox& bbox) { + T one = static_cast(1.0); + T zero = static_cast(0.0); + BBox clipped_bbox; + clipped_bbox.x_min = std::max(std::min(bbox.x_min, one), zero); + clipped_bbox.y_min = std::max(std::min(bbox.y_min, one), zero); + clipped_bbox.x_max = std::max(std::min(bbox.x_max, one), zero); + clipped_bbox.y_max = std::max(std::min(bbox.y_max, one), zero); + return clipped_bbox; +} +template +void GetDetectionOutput( + const T* conf_data, const size_t num_kept, const size_t num_priors, + const size_t num_classes, const size_t batch_size, + const std::vector>>& all_indices, + const std::vector>>& all_decoded_bboxes, T* out_data) { + size_t count = 0; + for (size_t n = 0; n < batch_size; ++n) { + for (std::map>::const_iterator it = + all_indices[n].begin(); + it != all_indices[n].end(); ++it) { + size_t label = it->first; + const std::vector& indices = it->second; + const std::vector>& decoded_bboxes = all_decoded_bboxes[n]; + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + size_t conf_offset = n * num_priors * num_classes + idx * num_classes; + out_data[count * 7] = n; + out_data[count * 7 + 1] = label; + out_data[count * 7 + 2] = (conf_data + conf_offset)[label]; + BBox clipped_bbox = ClipBBox(decoded_bboxes[idx]); + out_data[count * 7 + 3] = clipped_bbox.x_min; + out_data[count * 7 + 4] = clipped_bbox.y_min; + out_data[count * 7 + 5] = clipped_bbox.x_max; + out_data[count * 7 + 6] = clipped_bbox.y_max; + ++count; + } + } + } +} +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/gru_compute.cc b/paddle/operators/math/gru_compute.cc index d570c68cd458914c8951c4ce50a02e3c5b1acab0..101ab859624869bf34d171cd42d46d0c5bdac29c 100644 --- a/paddle/operators/math/gru_compute.cc +++ b/paddle/operators/math/gru_compute.cc @@ -21,9 +21,9 @@ namespace math { template struct GRUUnitFunctor { static void compute(const platform::CPUDeviceContext &context, - hl_gru_value value, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + GRUMetaValue value, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { #ifndef __NVCC__ if (value.prev_out_value) { math::gemm( @@ -51,10 +51,10 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::CPUDeviceContext &context, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, grad, frame_size, batch_size, active_node); diff --git a/paddle/operators/math/gru_compute.cu b/paddle/operators/math/gru_compute.cu index dd518cd1e4bea52f0d463150114feed3ceea0ccb..d5a0e630ea0eadea990988c3170395c842a91900 100644 --- a/paddle/operators/math/gru_compute.cu +++ b/paddle/operators/math/gru_compute.cu @@ -21,9 +21,9 @@ namespace math { template struct GRUUnitFunctor { static void compute(const platform::CUDADeviceContext &context, - hl_gru_value value, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + GRUMetaValue value, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { auto stream = context.stream(); dim3 threads; dim3 grid; @@ -88,10 +88,10 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::CUDADeviceContext &context, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { auto stream = context.stream(); dim3 threads; dim3 grid; diff --git a/paddle/operators/math/gru_compute.h b/paddle/operators/math/gru_compute.h index ca1343cb2c5c1eb8da92c2f06b25902c1c2fe8b3..bf69147b506661692a6d71823043cd3506ea8b5d 100644 --- a/paddle/operators/math/gru_compute.h +++ b/paddle/operators/math/gru_compute.h @@ -11,7 +11,7 @@ limitations under the License. */ #pragma once -#include "paddle/operators/math/lstm_compute.h" +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -19,9 +19,8 @@ namespace paddle { namespace operators { namespace math { -// TODO(guosheng): refine code style in gru_compute template -struct hl_gru_value { +struct GRUMetaValue { T *gate_weight; T *state_weight; T *gate_value; @@ -31,7 +30,7 @@ struct hl_gru_value { }; template -struct hl_gru_grad { +struct GRUMetaGrad { T *gate_weight_grad; T *state_weight_grad; T *gate_grad; @@ -42,18 +41,18 @@ struct hl_gru_grad { template struct GRUUnitFunctor { - static void compute(const DeviceContext &context, hl_gru_value value, + static void compute(const DeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate); + const detail::ActivationType active_node, + const detail::ActivationType active_gate); }; template struct GRUUnitGradFunctor { - static void compute(const DeviceContext &context, hl_gru_value value, - hl_gru_grad grad, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate); + static void compute(const DeviceContext &context, GRUMetaValue value, + GRUMetaGrad grad, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate); }; } // namespace math diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index 954762f92286fe13bd2c08ec03c3ac96bb663cca..e1ad6b64d201ef99d83eaa2a821356008dcc9c8e 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -22,14 +22,6 @@ namespace paddle { namespace operators { namespace math { -typedef enum { - HL_ACTIVATION_SIGMOID = 0, - HL_ACTIVATION_RELU = 1, - HL_ACTIVATION_TANH = 2, - HL_ACTIVATION_LINEAR = 3, - HL_ACTIVATION_END -} activation_mode_t; - template struct LstmMetaValue { T *gate_value; @@ -54,20 +46,6 @@ struct LstmMetaGrad { T *check_og_grad; }; -inline activation_mode_t ActiveType(const std::string &type) { - if (type == "sigmoid") { - return HL_ACTIVATION_SIGMOID; - } else if (type == "relu") { - return HL_ACTIVATION_RELU; - } else if (type == "tanh") { - return HL_ACTIVATION_TANH; - } else if (type == "linear" || type == "identity" || type == "") { - return HL_ACTIVATION_LINEAR; - } else { - PADDLE_THROW("Do not support activation type."); - } -} - template class LstmUnitFunctor { public: diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc index ab758d1e7fd8ab361948b28e8cb735b9a742a339..8a1ebb58c26578f076bf243adfbd51d10c682b99 100644 --- a/paddle/operators/math/selected_rows_functor.cc +++ b/paddle/operators/math/selected_rows_functor.cc @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/selected_rows_functor.h" +#include + #include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" namespace paddle { namespace operators { @@ -179,6 +181,118 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +// This is a separated namespace for manipulate SelectedRows typed +// data. Like merge duplicated rows, adding two SelectedRows etc. +// +// Another group of functors is called "scatter updates", which means +// use SelectedRows to update a dense tensor with different Ops, like +// add or mul. +namespace scatter { + +size_t FindPos(const std::vector& rows, int64_t value) { + return std::find(rows.begin(), rows.end(), value) - rows.begin(); +} + +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + auto input_rows = input.rows(); + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto input_width = input.value().dims()[1]; + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = FindPos(merge_rows, input_rows[i]); + for (int64_t j = 0; j < input_width; j++) { + out_data[out_i * input_width + j] += input_data[i * input_width + j]; + } + } + return out; + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +struct UpdateToTensor { + void operator()(const platform::CPUDeviceContext& context, + const ScatterOps& op, const framework::SelectedRows& input1, + framework::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + // FIXME(typhoonzero): use macro fix the below messy code. + switch (op) { + case ScatterOps::ASSIGN: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::ADD: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUB: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] -= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUBBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] - + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + case ScatterOps::MUL: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] *= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIV: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] /= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIVBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] / + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + } + } +}; + +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu index 9fddd97a36f7fdb6628d6eeb192cb216fdae3e5b..0ee456f9bc61436bd0f2f8ef20dd1654e7e56d56 100644 --- a/paddle/operators/math/selected_rows_functor.cu +++ b/paddle/operators/math/selected_rows_functor.cu @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include + #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/selected_rows_functor.h" #include "paddle/platform/cuda_helper.h" @@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; + +namespace scatter { + +template +__global__ void MergeAddKernel(const T* input, const int64_t* input_rows, + T* out, const int64_t* out_rows, + size_t out_rows_size, int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + __shared__ size_t out_idx; + + if (tid == 0) { + for (size_t i = 0; i < out_rows_size; i++) { + if (input_rows[ty] == out_rows[i]) { + out_idx = i; + } + } + } + + __syncthreads(); + + input += ty * row_numel; + out += out_idx * row_numel; + for (int index = tid; index < row_numel; index += block_size) { + paddle::platform::CudaAtomicAdd(out + index, input[index]); + } +} + +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::CUDADeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + auto input_rows = input.rows(); + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto input_width = input.value().dims()[1]; + + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid1(1, input_rows.size()); + + MergeAddKernel< + T, 256><<(context) + .stream()>>>(input_data, input.rows().data(), out_data, + out.rows().data(), out.rows().size(), + input_width); + return out; + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +__global__ void UpdateToTensorKernel(const T* selected_rows, + const int64_t* rows, const ScatterOps& op, + T* tensor_out, int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_out += rows[ty] * row_numel; + // FIXME(typhoonzero): use macro fix the below messy code. + switch (op) { + case ScatterOps::ASSIGN: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index]; + } + break; + case ScatterOps::ADD: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] += selected_rows[index]; + } + break; + case ScatterOps::SUB: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] -= selected_rows[index]; + } + break; + case ScatterOps::SUBBY: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index] - tensor_out[index]; + } + break; + case ScatterOps::MUL: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] *= selected_rows[index]; + } + break; + case ScatterOps::DIV: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] /= selected_rows[index]; + } + break; + case ScatterOps::DIVBY: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index] / tensor_out[index]; + } + break; + } +} + +template +struct UpdateToTensor { + void operator()(const platform::CUDADeviceContext& context, + const ScatterOps& op, const framework::SelectedRows& input1, + framework::Tensor* input2) { + // NOTE: Use SelectedRowsAddToTensor for better performance + // no additional MergeAdd called. + MergeAdd merge_func; + auto merged_in1 = merge_func(context, input1); + + auto in1_height = merged_in1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = merged_in1.value(); + auto& in1_rows = merged_in1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.template data(); + auto* in2_data = input2->data(); + + dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1); + dim3 grid(1, in1_rows.size()); + UpdateToTensorKernel<<< + grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op, + in2_data, in1_row_numel); + } +}; +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h index 1149075abf16547a120ac8928c45b4972409fc72..09d4631905f90f78772368ad71b11826877bdc34 100644 --- a/paddle/operators/math/selected_rows_functor.h +++ b/paddle/operators/math/selected_rows_functor.h @@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/selected_rows.h" #include "paddle/platform/device_context.h" +#define INLINE_FOR2(sizei, sizej) \ + for (int64_t i = 0; i < sizei; i++) \ + for (int64_t j = 0; j < sizej; j++) + namespace paddle { namespace operators { namespace math { @@ -52,6 +57,78 @@ struct SelectedRowsAddToTensor { framework::Tensor* input2); }; +namespace scatter { +// functors for manuplating SelectedRows data +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input); +}; + +template +struct Add { + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); + e_out.device(*context.eigen_device()) = e_in1 + e_in2; + return out; + } +}; + +template +struct Mul { + // multiply two SelectedRows + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); + e_out.device(*context.eigen_device()) = e_in1 * e_in2; + return out; + } + // multiply scalar to SelectedRows + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const T input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + e_out.device(*context.eigen_device()) = input2 * e_in1; + return out; + } +}; + +enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; + +// out = seleted_rows_in / tensor +template +struct UpdateToTensor { + void operator()(const DeviceContext& context, const ScatterOps& op, + const framework::SelectedRows& input1, + framework::Tensor* input2); +}; + +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index e5ef0740b6f385de7f17a3a419000cb8c897d986..b37269b471b4d71b42c41641fd14c7a64d2719d6 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -116,9 +116,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { auto height = dout_tensor.dims()[0]; auto slice = dx_tensor.Slice(0, static_cast(height)); framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice); - if (dx_tensor.dims()[0] < height) { + if (dx_tensor.dims()[0] > height) { auto rest_tensor = dx_tensor.Slice( - static_cast(height), static_cast(dout_tensor.dims()[0])); + static_cast(height), static_cast(dx_tensor.dims()[0])); math::set_constant(dev_ctx, &rest_tensor, 0.0f); } } diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h index eaa36aa1aea53e0b37ef6c578d8bb1cda230ded0..552b48f608b7e0248f03dbea940a83f112a67712 100644 --- a/paddle/operators/sum_op.h +++ b/paddle/operators/sum_op.h @@ -37,11 +37,11 @@ class SumKernel : public framework::OpKernel { bool in_place = out_var == in_vars[0]; if (out_var->IsType()) { - auto *out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - + auto *out = context.Output("Out"); + if (!in_place) { + out->mutable_data(context.GetPlace()); + } auto result = EigenVector::Flatten(*out); - if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index 53e38ec70336ca7f2d7c142e5fb1bbe427ab2957..d5ff3e3fce29b1a888b2cd4d307c2655669e3e4c 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp { auto &x_array = x->Get(); auto *out = scope.FindVar(Output("Out")); PADDLE_ENFORCE(out != nullptr, "Out must be set"); - auto *out_tensor = out->GetMutable(); size_t offset = GetOffset(scope, place); if (offset < x_array.size()) { + auto *out_tensor = out->GetMutable(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index dfef2c16d8f2277d57cbcfe51d108402e518799b..2b366e6383d23e2d31a194edd04412892a8311eb 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext { std::unique_ptr eigen_device_; }; +template +struct DefaultDeviceContextType; + +template <> +struct DefaultDeviceContextType { + using TYPE = CPUDeviceContext; +}; + #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; @@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; +template <> +struct DefaultDeviceContextType { + using TYPE = CUDADeviceContext; +}; + class CUDNNDeviceContext : public CUDADeviceContext { public: explicit CUDNNDeviceContext(CUDAPlace place); @@ -125,6 +138,13 @@ class DeviceContextPool { /*! \brief Return handle of single device context. */ const platform::DeviceContext* Get(const platform::Place& place); + template + const typename DefaultDeviceContextType::TYPE* GetByPlace( + const Place& place) { + return reinterpret_cast< + const typename DefaultDeviceContextType::TYPE*>(Get(place)); + } + private: static DeviceContextPool* pool; constexpr static int LEFT_SHIFT = 8; diff --git a/paddle/platform/for_range.h b/paddle/platform/for_range.h index 5427aa28238d6b46eb72d1fb49303dce3d871d7d..694a66d9ac4eb6ad02daf1931806fa1287de7cab 100644 --- a/paddle/platform/for_range.h +++ b/paddle/platform/for_range.h @@ -62,7 +62,7 @@ struct ForRange { template inline void operator()(Function func) const { - constexpr size_t num_threads = 1024; + constexpr int num_threads = 1024; int block_size = limit_ <= num_threads ? limit_ : num_threads; int grid_size = (limit_ + num_threads - 1) / num_threads; diff --git a/paddle/platform/place.h b/paddle/platform/place.h index d25eaa689f4a4baa951db5c61bbf99288e365ee1..76b5c502cc48431a4e9b13b07505978884576e1d 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include - +#include "paddle/platform/enforce.h" #include "paddle/platform/variant.h" namespace paddle { @@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &); std::ostream &operator<<(std::ostream &, const Place &); +template +struct PlaceVisitorWrapper + : public boost::static_visitor { + const Visitor &visitor_; + explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {} + + typename Visitor::result_type operator()(const CPUPlace &cpu) const { + return visitor_(cpu); + } + + typename Visitor::result_type operator()(const CUDAPlace &cuda) const { +#ifdef PADDLE_WITH_CUDA + return visitor_(cuda); +#else + PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device"); + return typename Visitor::result_type(); +#endif + } +}; + +template +typename Visitor::result_type VisitPlace(const Place &place, + const Visitor &visitor) { + return boost::apply_visitor(PlaceVisitorWrapper(visitor), place); +} + } // namespace platform } // namespace paddle diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 6afed7eec7001b646d55cef0bc3f59782b80b15f..7b374307071d2da91a677361b404448f1a3816b0 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -3,6 +3,9 @@ if(WITH_PYTHON) SRCS pybind.cc exception.cc protobuf.cc const_value.cc DEPS pybind python backward proto_desc paddle_memory executor prune init ${GLOB_OP_LIB}) + if(NOT APPLE AND NOT ANDROID) + target_link_libraries(paddle_pybind rt) + endif(NOT APPLE AND NOT ANDROID) endif(WITH_PYTHON) if(WITH_DOC) diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h index 4d5e73e2c266b301de4f19e09be7ab4009c936d3..6b4290972bade585d1a0c2ae919a2e712bdf308c 100644 --- a/paddle/pybind/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -77,10 +77,10 @@ struct CastToPyBufferImpl { } else if (paddle::platform::is_cpu_place(tensor.place())) { dst_tensor = tensor; } - return py::buffer_info( - dst_tensor.mutable_data(dst_tensor.place()), - sizeof(CUR_TYPE), py::format_descriptor::format(), - (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); + return py::buffer_info(dst_tensor.data(), sizeof(CUR_TYPE), + py::format_descriptor::format(), + (size_t)framework::arity(dst_tensor.dims()), + dims_outside, strides); } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index e43b9c218a3ecb9e7f20fb7e8b14a85a29947eef..92039ec6b05d224e702f0ba5dc05c057a492287e 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -178,7 +178,7 @@ EOF # run paddle version to install python packages first RUN apt-get update &&\ ${NCCL_DEPS}\ - apt-get install -y wget python-pip dmidecode && pip install -U pip && \ + apt-get install -y wget python-pip dmidecode python-tk && pip install -U pip && \ pip install /*.whl; apt-get install -f -y && \ apt-get clean -y && \ rm -f /*.whl && \ diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index a94bc01b358c508132eb85920a2d4c0aa934dd51..8a352b0078d701f797f7202c85bd0e08201ac9b8 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -71,9 +71,7 @@ function threads_config() { # auto set OMP_NUM_THREADS and MKL_NUM_THREADS # according to trainer_count and total processors # only when MKL enabled - if [ "@WITH_MKL@" == "OFF" ]; then - return 0 - fi + # auto set OPENBLAS_NUM_THREADS when do not use MKL processors=`grep "processor" /proc/cpuinfo|sort -u|wc -l` trainers=`grep -Eo 'trainer_count.[0-9]+' <<< "$@" |grep -Eo '[0-9]+'|xargs` if [ -z $trainers ]; then @@ -83,12 +81,19 @@ function threads_config() { if [ $threads -eq 0 ]; then threads=1 fi - if [ -z "$OMP_NUM_THREADS" ]; then - export OMP_NUM_THREADS=$threads - fi - if [ -z "$MKL_NUM_THREADS" ]; then - export MKL_NUM_THREADS=$threads + if [ "@WITH_MKL@" == "ON" ]; then + if [ -z "$OMP_NUM_THREADS" ]; then + export OMP_NUM_THREADS=$threads + fi + if [ -z "$MKL_NUM_THREADS" ]; then + export MKL_NUM_THREADS=$threads + fi + else + if [ -z "$OPENBLAS_NUM_THREADS" ]; then + export OPENBLAS_NUM_THREADS=$threads + fi fi + } PADDLE_CONF_HOME="$HOME/.config/paddle" @@ -150,7 +155,7 @@ fi case "$1" in "train") threads_config $@ - # echo $OMP_NUM_THREADS $MKL_NUM_THREADS + # echo $OMP_NUM_THREADS $MKL_NUM_THREADS $OPENBLAS_NUM_THREADS ${DEBUGGER} $PADDLE_BIN_PATH/paddle_trainer ${@:2} ;; "merge_model") diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index 634388094c804827657dc83d5c205e680625b156..7bdddeaabec733ef26b3f766c6437f5c53d65044 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -44,7 +44,7 @@ __all__ = ['train', 'test', 'valid'] DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' -DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' +DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index c72b5730695dbc4f772015f1fb8dec6814cd1837..225b41c5043b5792abb90bbad53cbbfce9a3156e 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -36,7 +36,7 @@ def __read_gflags_from_env__(): """ import sys import core - read_env_flags = ['use_pinned_memory'] + read_env_flags = ['use_pinned_memory', 'check_nan_inf'] if core.is_compile_gpu(): read_env_flags.append('fraction_of_gpu_memory_to_use') core.init_gflags([sys.argv[0]] + diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 6966cc75804b6b5a49ceb45a26994c23d2936bdb..f11c83f59c930784ca355acc3193c3c352db10e5 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -5,14 +5,17 @@ import collections __all__ = ['append_backward'] -def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, - end_idx=None): +def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): + """ + Traverse all ops in op_descs[begin_idx : end_idx], + if any op has inputs/outputs named "old_name", rename it as 'new_name' + """ if begin_idx is None: begin_idx = 0 if end_idx is None: - end_idx = len(op_desc_list) + end_idx = len(op_descs) for i in range(begin_idx, end_idx): - op_desc = op_desc_list[i] + op_desc = op_descs[i] if isinstance(op_desc, tuple): op_desc = op_desc[0] op_desc.rename_input(old_name, new_name) @@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, def _create_op_desc_(op_type, inputs, outputs, attrs): + """ + Create a C++ OpDesc object with specified inputs, outputs and attributes. + """ op_desc = core.OpDesc() op_desc.set_type(op_type) for para, args in inputs.iteritems(): @@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): return op_desc -def _infer_var_data_type_(var_name, block): - grad_var = block.desc.find_var(var_name.encode("ascii")) - fwd_name = _strip_grad_suffix_(var_name.encode("ascii")) +def _infer_var_data_type_(grad_var_name, block): + """ + Infer the data type of given grad variable + """ + grad_var = block.desc.find_var(grad_var_name.encode("ascii")) + fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii")) if block.desc.has_var_recursive(fwd_name): fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) grad_var.set_dtype(fwd_var.dtype()) @@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block): def _all_in_set_(cands, s): + """ + Test if all elements of 'cands' are in set 's' + """ for c in cands: if not c in s: return False @@ -52,18 +64,29 @@ def _all_in_set_(cands, s): def _strip_grad_suffix_(name): + """ + Strip the grad suffix from the given varibale name + e.g. x@GRAD ==> x + y@GRAD@RENAME@1 ==> y + """ pos = name.find(core.grad_var_suffix()) return name[:pos] if pos != -1 else name def _append_grad_suffix_(name): + """ + Append grad suffix to the given variable name + e.g. x ==> x@GRAD + """ return name + core.grad_var_suffix() def _addup_repetitive_outputs_(op_descs): - # In backward part, an variable my be the output of more than one ops. - # In this case, the variable should be the accumulation of all the outputs. - # We adopt adding `sum_op`s to implement the accumulate. + """ + In backward part, an variable may be the output of more than one ops. + In this case, the variable should be the accumulation of all the outputs. + `sum_op`s are added to implement the accumulate. + """ pending_sum_ops = [] var_rename_count = collections.defaultdict(int) renamed_vars = collections.defaultdict(list) @@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs): def _remove_no_grad_branch_(op_descs, no_grad_set): + """ + Remove unnecessary grad ops + A grad op can be removed in two cases: + 1. all outputs of the grad op are in 'no_grad_set' + 2. (TODO) all grad inputs of the grad op are in 'no_grad_set' + """ # Remove ops whose outputs are all in no_grad_dict op_descs = filter( lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set), @@ -133,6 +162,21 @@ def _append_backward_ops_(target, no_grad_dict, grad_to_var, callback=None): + """ + Create all grad ops, and insert them into given block + + Args: + target(Variable): the target variable of forward pass + block(Block): the block where forward ops are + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(set) a set of varibale names. These varibales have no gradient + grad_to_var(dict)(output argument): + key(str): grad variable name + val(str): corresponding forward variable name + """ + # grad_op_descs holds created grad_op, and will be appended to target_block grad_op_descs = [] program = block.program for op in reversed(block.ops): @@ -145,6 +189,7 @@ def _append_backward_ops_(target, no_grad_dict, grad_to_var, callback) grad_sub_block_list.append(grad_sub_block.desc) + # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, no_grad_dict[block.idx], grad_sub_block_list) grad_op_descs.extend(grad_op_desc) @@ -170,6 +215,20 @@ def _append_backward_ops_(target, def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): + """ + Create new variables required by backward pass. + + Args: + block(Block): the block where new variables will be created + start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created + grad_to_var(dict): + key(str): grad variable name + val(str): corresponding forward variable name + In most cases, this dict is generated by _append_backward_ops_() + grad_info_map(dict)(output argument): + key(str): forward variable name + val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index + """ for op_idx in range(start_op_idx, block.desc.op_size()): op_desc = block.desc.op(op_idx) if op_desc.has_attr("sub_block"): @@ -197,18 +256,18 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def append_backward(loss, parameter_list=None, no_grad_set=None): """ - Create and add gradient Operators in BlockDesc to compute - gradients of `loss` for parameters in parameter_list - - :param loss: an variable generated by cost function. - :type loss: Variable - :param no_grad_dict: variable that should not create gradient - :type no_grad_dict: set - :param parameter_list: parameters that need to compute gradient and - update to optimize the lost. - :type: list - :return: list of (parameters, gradients) pair. - :rtype: list[Variable] + Append backward part to main_program + + Args: + loss(Variable): The variable generated by cost function. + parameter_list(list): Parameters that need to be updated by optimizer. + If None, it means all parameters need to be updated. + no_grad_set(set): Variables that have no gradients in Block 0. + If None, the set will be generated inside the function and + contains all variables with `step_gradient=True` from all blocks. + + Return: + (list[Variable]): list of (parameters, gradients) pair. """ assert isinstance(loss, framework.Variable) diff --git a/python/paddle/v2/fluid/data_feeder.py b/python/paddle/v2/fluid/data_feeder.py index 30a542af212926c93381aade426e25f2117e4662..24036c3e75b9594ba58cccb02825ab8020d1e107 100644 --- a/python/paddle/v2/fluid/data_feeder.py +++ b/python/paddle/v2/fluid/data_feeder.py @@ -3,7 +3,7 @@ import core import numpy import six.moves as six -from framework import Variable +from framework import Variable, default_main_program __all__ = ['DataFeeder'] @@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object): class DataFeeder(object): - def __init__(self, feed_list, place): + def __init__(self, feed_list, place, program=None): self.feed_dtypes = [] self.feed_names = [] self.feed_shapes = [] self.feed_lod_level = [] + if program is None: + program = default_main_program() for each_var in feed_list: + if isinstance(each_var, basestring): + each_var = program.block(0).var(each_var) if not isinstance(each_var, Variable): raise TypeError("Feed list should contain a list of variable") self.feed_dtypes.append(each_var.dtype) diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index 2c91afb363bf72f2791e60c6df0d9130ccd698c5..1d6c594b41a2c295e3818fb119362d1daba1de33 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -1,12 +1,31 @@ import numpy as np +import contextlib +from framework import Program, default_main_program from . import core -from framework import Program, default_main_program, Parameter, Variable -__all__ = ['Executor', 'g_scope'] +__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope'] g_scope = core.Scope() +def global_scope(): + return g_scope + + +def switch_scope(scope): + global g_scope + ex = g_scope + g_scope = scope + return ex + + +@contextlib.contextmanager +def scope_guard(scope): + ex = switch_scope(scope) + yield + switch_scope(ex) + + def as_numpy(tensor): if isinstance(tensor, list): return [as_numpy(t) for t in tensor] @@ -117,7 +136,7 @@ class Executor(object): raise TypeError() if scope is None: - scope = g_scope + scope = global_scope() program = program.clone() global_block = program.global_block() diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 78b8d97b2e9cb9d78dd179070c59e09734042936..926327b70c70d250766c2640d808bb3e3516d37b 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -188,7 +188,7 @@ def save_inference_model(dirname, raise ValueError("'feed_var_names' should be a list of str.") if isinstance(target_vars, Variable): - feeded_var_names = [feeded_var_names] + target_vars = [target_vars] else: if not (bool(target_vars) and all( isinstance(var, Variable) for var in target_vars)): diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index c3591a613acafb268a5bd70618cd4555450bac29..8acd470c5ed5fa8eeda396f1e9182db4ecdd7016 100644 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -170,7 +170,7 @@ def main(): exe.run(fluid.default_startup_program()) - embedding_param = fluid.g_scope.find_var(embedding_name).get_tensor() + embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor() embedding_param.set( load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place) diff --git a/python/paddle/v2/fluid/tests/decorators.py b/python/paddle/v2/fluid/tests/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..154619b0e93455922700a12d734967b4d20c4f13 --- /dev/null +++ b/python/paddle/v2/fluid/tests/decorators.py @@ -0,0 +1,29 @@ +import paddle.v2.fluid as fluid + +__all__ = ['many_times', 'prog_scope'] + + +def many_times(times): + def __impl__(fn): + def __fn__(*args, **kwargs): + for _ in range(times): + fn(*args, **kwargs) + + return __fn__ + + return __impl__ + + +def prog_scope(): + def __impl__(fn): + def __fn__(*args, **kwargs): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + fn(*args, **kwargs) + + return __fn__ + + return __impl__ diff --git a/python/paddle/v2/fluid/tests/test_adam_op.py b/python/paddle/v2/fluid/tests/test_adam_op.py index a0d6655d4cbcff8ed3d55df0f4e68fc6591fbb11..7dbc2fa0858a68c5da9e8d48dcb187494357e940 100644 --- a/python/paddle/v2/fluid/tests/test_adam_op.py +++ b/python/paddle/v2/fluid/tests/test_adam_op.py @@ -1,6 +1,8 @@ import unittest import numpy as np from op_test import OpTest +from paddle.v2.fluid import core +from paddle.v2.fluid.op import Operator class TestAdamOp1(OpTest): @@ -176,5 +178,124 @@ def adam_step(inputs, attributes): return param_out, moment1_out, moment2_out +def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad): + ''' + Simulate one step of the adam optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + # grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment1_out = np.zeros(shape=[height, row_numel]) + moment2_out = np.zeros(shape=[height, row_numel]) + param_out = np.zeros(shape=[height, row_numel]) + + for idx, row_id in enumerate(rows): + moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 + ) * np_grad[idx] + moment2_out[row_id] = beta2 * moment2[row_id] + ( + 1 - beta2) * np.square(np_grad[idx]) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) + param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / ( + np.sqrt(moment2_out[row_id]) + epsilon)) + return param_out, moment1_out, moment2_out + + +class TestSparseAdamOp(unittest.TestCase): + def setup(self, scope, place): + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + + height = 10 + rows = [0, 4, 7] + self.rows = rows + row_numel = 12 + self.row_numel = row_numel + self.dense_inputs = { + "Param": np.full((height, row_numel), 5.0).astype("float32"), + "Moment1": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + 'Beta1Pow': np.array([beta1**10]).astype("float32"), + 'Beta2Pow': np.array([beta2**10]).astype("float32"), + "LearningRate": np.full((1), 2.0).astype("float32") + } + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + self.sparse_inputs = ["Grad"] + + param_out, mom1, mom2 = adam_step_sparse( + self.dense_inputs, self.attrs, height, rows, row_numel, np_array) + self.outputs = { + "ParamOut": param_out, + "Moment1Out": mom1, + "Moment2Out": mom2 + } + + def check_with_place(self, place): + scope = core.Scope() + self.setup(scope, place) + + op_args = dict() + for key, np_array in self.dense_inputs.iteritems(): + var = scope.var(key).get_tensor() + var.set(np_array, place) + op_args[key] = key + for s in self.sparse_inputs: + op_args[s] = s + for s in self.outputs: + var = scope.var(s).get_tensor() + var.set(self.outputs[s], place) + op_args[s] = s + for k in self.attrs: + op_args[k] = self.attrs[k] + + # create and run sgd operator + adam_op = Operator("adam", **op_args) + adam_op.run(scope, place) + + for key, np_array in self.outputs.iteritems(): + out_var = scope.var(key).get_tensor() + actual = np.array(out_var) + actual = actual.reshape([actual.size]) + np_array = np_array.reshape([np_array.size]) + for idx, row_id in enumerate(self.rows): + j = 0 + while j < self.row_numel: + pos = row_id * self.row_numel + j + self.assertLess((actual[pos] - np_array[pos]) / actual[pos], + 0.00001) + j += 1 + + def test_sparse_sgd(self): + places = [core.CPUPlace()] + if core.is_compile_gpu(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_detection_output_op.py b/python/paddle/v2/fluid/tests/test_detection_output_op.py new file mode 100644 index 0000000000000000000000000000000000000000..080a9743b0182cb7e6dd0030fc306a7f82510a05 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_detection_output_op.py @@ -0,0 +1,57 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestUnpoolOp(OpTest): + def setUp(self): + self.op_type = "detection_output" + self.init_test_case() + + #loc.shape ((1, 4, 4, 1, 1)) + #conf.shape ((1, 4, 2, 1, 1)) + + loc = np.array([[[[[0.1]], [[0.1]], [[0.1]], [[0.1]]], + [[[0.1]], [[0.1]], [[0.1]], [[0.1]]], + [[[0.1]], [[0.1]], [[0.1]], [[0.1]]], + [[[0.1]], [[0.1]], [[0.1]], [[0.1]]]]]) + conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]], + [[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]]) + priorbox = np.array([ + 0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.1, + 0.1, 0.2, 0.2, 0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2, 0.4, 0.4, + 0.8, 0.8, 0.1, 0.1, 0.2, 0.2 + ]) + + output = np.array([ + 0, 1, 0.68997443, 0.099959746, 0.099959746, 0.50804031, 0.50804031 + ]) + self.inputs = { + 'Loc': loc.astype('float32'), + 'Conf': conf.astype('float32'), + 'PriorBox': priorbox.astype('float32') + } + self.attrs = { + 'num_classes': self.num_classes, + 'top_k': self.top_k, + 'nms_top_k': self.nms_top_k, + 'background_label_id': self.background_label_id, + 'nms_threshold': self.nms_threshold, + 'confidence_threshold': self.confidence_threshold, + } + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def init_test_case(self): + self.num_classes = 2 + self.top_k = 10 + self.nms_top_k = 20 + self.background_label_id = 0 + self.nms_threshold = 0.01 + self.confidence_threshold = 0.01 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py b/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py new file mode 100644 index 0000000000000000000000000000000000000000..c02c59284e1ca2e28ba2f6c5ec13b241c15fc288 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py @@ -0,0 +1,347 @@ +import numpy +import random +import collections +import paddle.v2.fluid as fluid +import unittest +from decorators import * + + +class Memory(object): + def __init__(self, shape, dtype='float32'): + self.ex = numpy.zeros(shape=shape, dtype=dtype) + self.cur = None + + def update(self, val): + assert val.shape == self.ex.shape + assert val.dtype == self.ex.dtype + self.cur = val + + def ex(self): + return self.ex + + def next(self): + self.ex = self.cur + self.cur = None + + def __next__(self): + self.next() + + def reset(self): + self.ex = numpy.zeros(shape=self.ex.shape, dtype=self.ex.dtype) + self.cur = None + + +class Output(object): + def __init__(self): + self.outs = [] + + def next_sequence(self): + self.outs.append([]) + + def out(self, val): + self.outs[-1].append(val) + + def last(self): + return self.outs[-1][-1] + + +class BaseRNN(object): + def __init__(self, ins, mems, params, outs, num_seq=5, max_seq_len=15): + self.num_seq = num_seq + self.inputs = collections.defaultdict(list) + + for _ in xrange(num_seq): + seq_len = random.randint(1, max_seq_len - 1) + for iname in ins: + ishape = ins[iname].get('shape', None) + idtype = ins[iname].get('dtype', 'float32') + lst = [] + for _ in xrange(seq_len): + lst.append(numpy.random.random(size=ishape).astype(idtype)) + self.inputs[iname].append(lst) + + self.mems = dict() + for mname in mems: + mshape = mems[mname].get('shape', None) + mdtype = mems[mname].get('dtype', 'float32') + self.mems[mname] = Memory(shape=mshape, dtype=mdtype) + + self.params = dict() + for pname in params: + pshape = params[pname].get('shape', None) + pdtype = params[pname].get('dtype', 'float32') + self.params[pname] = numpy.random.random(size=pshape).astype(pdtype) + + self.outputs = dict() + + for oname in outs: + self.outputs[oname] = Output() + + def step(self, **kwargs): + raise NotImplementedError() + + def exe(self): + retv = dict() + for out in self.outputs: + retv[out] = [] + + for seq_id in xrange(self.num_seq): + for mname in self.mems: + self.mems[mname].reset() + for out in self.outputs: + self.outputs[out].next_sequence() + + iname0 = self.inputs.keys()[0] + seq_len = len(self.inputs[iname0][seq_id]) + + for step_id in xrange(seq_len): + xargs = dict() + + for iname in self.inputs: + xargs[iname] = self.inputs[iname][seq_id][step_id] + + for mname in self.mems: + xargs[mname] = self.mems[mname] + + for pname in self.params: + xargs[pname] = self.params[pname] + + for out in self.outputs: + xargs[out] = self.outputs[out] + + self.step(**xargs) + + for mname in self.mems: + next(self.mems[mname]) + + for out in self.outputs: + retv[out].append(self.outputs[out].last()) + + for out in retv: + retv[out] = numpy.array(retv[out]) + return retv + + def to_feed(self, place): + feed_dict = dict() + + for iname in self.inputs: + lod = [0] + np_flatten = [] + for seq_id in xrange(len(self.inputs[iname])): + seq_len = len(self.inputs[iname][seq_id]) + lod.append(lod[-1] + seq_len) + np_flatten.extend(self.inputs[iname][seq_id]) + + t = fluid.Tensor() + t.set(numpy.array(np_flatten), place) + t.set_lod([lod]) + feed_dict[iname] = t + + for pname in self.params: + feed_dict[pname] = self.params[pname] + return feed_dict + + def get_numeric_gradient_of_param(self, param_name, delta=0.001): + p = self.params[param_name] + if len(p.shape) != 2: + raise ValueError("Not support get numeric gradient of an parameter," + " which is not matrix") + g = numpy.zeros(shape=p.shape, dtype=p.dtype) + + for i in xrange(p.shape[0]): + for j in xrange(p.shape[1]): + o = p[i][j] + p[i][j] += delta + pos = self._exe_mean_out_() + p[i][j] -= 2 * delta + neg = self._exe_mean_out_() + p[i][j] = o + g[i][j] = (pos - neg) / (delta * 2) + return g + + def get_numeric_gradient_of_input(self, + input_name, + delta=0.001, + return_one_tensor=True): + ipt = self.inputs[input_name] + grad = [] + + for seq in ipt: + seq_grad = [] + for item in seq: + item_grad = numpy.zeros(shape=item.shape, dtype=item.dtype) + if len(item.shape) != 1: + raise ValueError("Not support") + + for i in xrange(len(item)): + o = item[i] + item[i] += delta + pos = self._exe_mean_out_() + item[i] -= 2 * delta + neg = self._exe_mean_out_() + item[i] = o + item_grad[i] = (pos - neg) / (delta * 2) + seq_grad.append(item_grad) + grad.append(seq_grad) + + if not return_one_tensor: + return grad + + for i in xrange(len(grad)): + grad[i] = numpy.concatenate(grad[i]) + grad = numpy.concatenate(grad) + return grad + + def _exe_mean_out_(self): + outs = self.exe() + return numpy.array([o.mean() for o in outs.itervalues()]).mean() + + +class TestSimpleMul(unittest.TestCase): + DATA_NAME = 'X' + DATA_WIDTH = 32 + PARAM_NAME = 'W' + HIDDEN_WIDTH = 10 + OUT_NAME = 'Out' + + class SimpleMul(BaseRNN): + def __init__(self): + base = TestSimpleMul + super(base.SimpleMul, self).__init__({ + base.DATA_NAME: { + 'shape': [base.DATA_WIDTH] + } + }, {}, { + base.PARAM_NAME: { + 'shape': [base.DATA_WIDTH, base.HIDDEN_WIDTH] + } + }, [base.OUT_NAME]) + + def step(self, X, W, Out): + Out.out(numpy.matmul(X, W)) + + # Test many times in local to ensure the random seed cannot breaks CI + # @many_times(10) + @prog_scope() + def test_forward_backward(self): + py_rnn = TestSimpleMul.SimpleMul() + dat = fluid.layers.data( + name=self.DATA_NAME, shape=[self.DATA_WIDTH], lod_level=1) + dat.stop_gradient = False + + rnn = fluid.layers.DynamicRNN() + with rnn.block(): + d = rnn.step_input(dat) + o = fluid.layers.fc(input=d, + param_attr=self.PARAM_NAME, + bias_attr=False, + size=self.HIDDEN_WIDTH, + act=None) + rnn.output(o) + + out = rnn() + out = fluid.layers.sequence_pool(out, pool_type='last') + loss = fluid.layers.mean(x=out) + fluid.backward.append_backward(loss) + + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + out, w_g, i_g = map(numpy.array, + exe.run(feed=py_rnn.to_feed(cpu), + fetch_list=[ + out, self.PARAM_NAME + "@GRAD", + self.DATA_NAME + "@GRAD" + ], + return_numpy=False)) + out_by_python = py_rnn.exe()[self.OUT_NAME] + self.assertTrue(numpy.allclose(out, out_by_python)) + w_g_num = py_rnn.get_numeric_gradient_of_param(self.PARAM_NAME) + self.assertTrue(numpy.allclose(w_g_num, w_g, rtol=0.05)) + i_g_num = py_rnn.get_numeric_gradient_of_input( + input_name=self.DATA_NAME) + i_g_num = i_g_num.reshape(i_g.shape) + self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05)) + + +class TestSimpleMulWithMemory(unittest.TestCase): + DATA_WIDTH = 32 + HIDDEN_WIDTH = 20 + DATA_NAME = 'X' + PARAM_NAME = 'W' + + class SimpleMulWithMemory(BaseRNN): + def __init__(self): + super(TestSimpleMulWithMemory.SimpleMulWithMemory, self).__init__({ + TestSimpleMulWithMemory.DATA_NAME: { + 'shape': [TestSimpleMulWithMemory.DATA_WIDTH] + } + }, {'Mem': { + 'shape': [TestSimpleMulWithMemory.HIDDEN_WIDTH] + }}, { + TestSimpleMulWithMemory.PARAM_NAME: { + 'shape': [ + TestSimpleMulWithMemory.DATA_WIDTH, + TestSimpleMulWithMemory.HIDDEN_WIDTH + ] + } + }, ['Out']) + + def step(self, X, Mem, W, Out): + o = numpy.matmul(X, W) + assert isinstance(Mem, Memory) + o += Mem.ex + Mem.update(o) + assert isinstance(Out, Output) + Out.out(o) + + # many_times used locally for debug. Make sure the calculation is stable. + # @many_times(10) + @prog_scope() + def test_forward_backward(self): + py_rnn = TestSimpleMulWithMemory.SimpleMulWithMemory() + data = fluid.layers.data( + name=self.DATA_NAME, shape=[self.DATA_WIDTH], lod_level=1) + data.stop_gradient = False + rnn = fluid.layers.DynamicRNN() + with rnn.block(): + d = rnn.step_input(data) + mem = rnn.memory(value=0.0, shape=[self.HIDDEN_WIDTH]) + hidden = fluid.layers.fc(input=d, + size=self.HIDDEN_WIDTH, + param_attr=self.PARAM_NAME, + bias_attr=False, + act=None) + o = fluid.layers.elementwise_add(x=hidden, y=mem) + rnn.update_memory(mem, o) + rnn.output(o) + + out = rnn() + last = fluid.layers.sequence_pool(input=out, pool_type='last') + loss = fluid.layers.mean(x=last) + fluid.backward.append_backward(loss) + + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + feed = py_rnn.to_feed(cpu) + last_np, w_g, i_g = map(numpy.array, + exe.run(feed=feed, + fetch_list=[ + last, self.PARAM_NAME + "@GRAD", + self.DATA_NAME + "@GRAD" + ], + return_numpy=False)) + last_by_py, = py_rnn.exe().values() + w_g_num = py_rnn.get_numeric_gradient_of_param(self.PARAM_NAME) + self.assertTrue(numpy.allclose(last_np, last_by_py)) + + self.assertTrue(numpy.allclose(w_g_num, w_g, rtol=0.1)) + i_g_num = py_rnn.get_numeric_gradient_of_input(self.DATA_NAME) + i_g_num = i_g_num.reshape(i_g.shape) + + # Since this RNN has many float add. The number could be not stable. + # rtol = 0.1 + self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.1)) + + +if __name__ == '__main__': + unittest.main()