diff --git a/README.md b/README.md index 68421cf177f4cd15f8f44e8d00a27cafb5a13b91..5c428e9900762a208eebbfd053ce98663f803345 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ English | [简体中文](./README_cn.md) [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) -[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) -[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) +[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html) +[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) @@ -18,7 +18,7 @@ learning to many products at Baidu. Our vision is to enable deep learning for everyone via PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. -### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2) +### Latest PaddlePaddle Release: [Fluid 1.3.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.3) ### Install Latest Stable Release: ``` # Linux CPU @@ -26,9 +26,9 @@ pip install paddlepaddle # Linux GPU cuda9cudnn7 pip install paddlepaddle-gpu # Linux GPU cuda8cudnn7 -pip install paddlepaddle-gpu==1.2.0.post87 +pip install paddlepaddle-gpu==1.3.0.post87 # Linux GPU cuda8cudnn5 -pip install paddlepaddle-gpu==1.2.0.post85 +pip install paddlepaddle-gpu==1.3.0.post85 # For installation on other platform, refer to http://paddlepaddle.org/ ``` @@ -75,26 +75,26 @@ pip install paddlepaddle-gpu==1.2.0.post85 ## Installation -It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website. +It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html) on our website. ## Documentation -We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and -[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation. +We provide [English](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html) and +[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html) documentation. - [Deep Learning 101](https://github.com/PaddlePaddle/book) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html) +- [Distributed Training](http://paddlepaddle.org/documentation/docs/en/1.3/user_guides/howto/training/multi_node_en.html) You can run distributed training jobs on MPI clusters. 
-- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html) +- [Python API](http://paddlepaddle.org/documentation/docs/en/1.3/api/index_en.html) Our new API enables much shorter programs. -- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html) +- [How to Contribute](http://paddlepaddle.org/documentation/docs/en/1.3/advanced_usage/development/contribute_to_paddle/index_en.html) We appreciate your contributions! diff --git a/README_cn.md b/README_cn.md index dfb55b17ca4fd05ce5b7b85b2e26e4f7f7229763..b7b0e75e5524cc483a8c203a382e7f339f91694f 100644 --- a/README_cn.md +++ b/README_cn.md @@ -3,8 +3,8 @@ [English](./README.md) | 简体中文 [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) -[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) -[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) +[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html) +[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) @@ -16,7 +16,7 @@ PaddlePaddle (PArallel Distributed Deep LEarning) 是一个简单易用、高效 跟进PaddlePaddle最新特性请参考我们的[版本说明](https://github.com/PaddlePaddle/Paddle/releases) -### PaddlePaddle最新版本: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2) +### PaddlePaddle最新版本: [Fluid 1.3.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.3) ### 安装最新稳定版本: ``` # Linux CPU @@ -24,9 +24,9 @@ pip install paddlepaddle # Linux GPU cuda9cudnn7 pip install paddlepaddle-gpu # Linux GPU cuda8cudnn7 -pip install paddlepaddle-gpu==1.2.0.post87 +pip install paddlepaddle-gpu==1.3.0.post87 # Linux GPU cuda8cudnn5 -pip install paddlepaddle-gpu==1.2.0.post85 +pip install paddlepaddle-gpu==1.3.0.post85 # 其他平台上的安装指引请参考 http://paddlepaddle.org/ ``` @@ -57,26 +57,26 @@ pip install paddlepaddle-gpu==1.2.0.post85 ## 安装 -推荐阅读官网上的[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) +推荐阅读官网上的[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html) ## 文档 -我们提供[英文](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)和 -[中文](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) 文档 +我们提供[英文](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html)和 +[中文](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html) 文档 - [深度学习101](https://github.com/PaddlePaddle/book) 或许您想从这个在线交互式书籍开始,可以在Jupyter Notebook中运行 -- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html) +- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.3/user_guides/howto/training/multi_node.html) 可以在MPI集群上运行分布式训练任务 -- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html) +- [Python 
API](http://paddlepaddle.org/documentation/docs/zh/1.3/api_cn/index_cn.html) 新的API支持代码更少更简洁的程序 -- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html) +- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.3/advanced_usage/development/contribute_to_paddle/index_cn.html) 欢迎您的贡献! diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index 32a9368a9f639468a2144548118677aaddecbdf2..54826cedb871690a82b535ae3ed102600277c622 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -39,10 +39,8 @@ IF(WIN32) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) -ELSE() - #TODO(intel-huying): - # Now enable Erf function in mklml library temporarily, it will be updated as offical version later. - SET(MKLML_VER "VsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) +ELSE() + SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f24cf96cce30bf4b21e954ab022b2b99aeff37e6..62c96f8f5fe1ac550aa9f4e4fbb1d81af8f5b3be 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)) paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')) paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)) -paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None)) +paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None)) paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False)) paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', 
False)) paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,)) @@ -121,6 +121,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs= paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) +paddle.fluid.layers.sampled_softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_customized_samples', 'customized_samples', 'customized_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 8c6c9f35e84f4fd7d2b5486ac0eb60beceb512a2..17dd1399119d190bcbc31adb34ec61deb92a9994 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -135,12 +135,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { void AppendMultiDevPass(const BuildStrategy &strategy) { ir::Pass *multi_devices_pass; if (strategy_.is_distribution_) { + VLOG(3) << "multi device parameter server mode"; multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); } else { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { + VLOG(3) << "multi devices collective mode with allreduce"; multi_devices_pass = AppendPass("allreduce_mode_multi_devices_pass").get(); } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { + VLOG(3) << "multi deivces collective mode with reduce"; multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); } else { PADDLE_THROW("Unknown reduce strategy."); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 7d1e63f3682bca8965f6c5e695132dff44fa3715..478d2ffbcf2988487893984284d4597f018f0ca0 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -937,9 +937,21 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, } void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { - if (need_broadcast_var_ || - (UseGPU() && - strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) { + // broad cast received parameters when 
training in parameter server mode. + if (need_broadcast_var_) { + // There are 4 conditions: + // 1. GPU && Reduce: Reduce gradient then broadcast gradient to other GPUS. + // Need to broadcast received parameters to other GPU. + // 2. GPU && AllReduce: AllReduce all graident to each GPU. Need to + // broadcast received parameters to other GPU. + // 3. CPU && AllReduce: AllReduce all gradient to each thread. Need to + // broadcast received parameters to other scope. + // 4. CPU && Reduce: because all parameters share the same memory, did not + // broadcast received parameters. + if (!UseGPU() && + strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { + return; + } if (strategy_.fuse_broadcast_op_) { CreateFusedBroadcastOp(result, bcast_var_name_set_); } else { diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 2166b8b545c23025e029d283b7ca43719e31a259..a3f2a69aef52b6f55aa09e6dee2c22c048626c0d 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index e8f5530b7887c19679dd1071805c5a66602493ba..c7df3ea58a91579e35ff0d486516271a6daf054f 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -11,7 +11,6 @@ limitations under the License. */ #pragma once #include -#include #include #include #include @@ -25,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/float16.h" #ifdef PADDLE_WITH_MKLDNN @@ -303,28 +301,8 @@ template struct GeluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { -// Because the execute or device context can not be deliver here, it keep the -// marco for NVCC. 
-#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) - auto x_data = x.data(); - auto out_data = out.data(); - int n = std::min(x.size(), out.size()); - - std::memset(out_data, 0, n * sizeof(T)); - math::CBlas::AXPY(n, static_cast(M_SQRT1_2), x_data, 1, out_data, 1); - math::CBlas::VMERF(n, out_data, out_data, VML_LA); - for (int i = 0; i < n; i++) { - out_data[i] += static_cast(1); - } - math::CBlas::VMUL(n, x_data, out_data, out_data); - for (int i = 0; i < n; i++) { - out_data[i] *= static_cast(0.5); - } -#else auto temp = (x * static_cast(M_SQRT1_2)).erf(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); -#endif } }; diff --git a/paddle/fluid/operators/beam_search_decode_op.h b/paddle/fluid/operators/beam_search_decode_op.h index 6aefc5446f167eebb0da673b3fbdf7ed128daa98..0b883c3158fb922caae2e731875bbb8d43a1e9ca 100644 --- a/paddle/fluid/operators/beam_search_decode_op.h +++ b/paddle/fluid/operators/beam_search_decode_op.h @@ -122,7 +122,7 @@ void BeamSearchDecoder::ConvertSentenceVectorToLodTensor( auto cpu_place = std::unique_ptr( new paddle::platform::CPUPlace()); - paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place.get()); + paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place); framework::LoD lod; lod.push_back(source_level_lod); diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index 2a69ad4b53c26f5e2e0547e75e0d9c6518a8bcba..ab01bdf7ca8c5a369bd8838b1acc734364666992 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -144,34 +144,40 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "The ignore threshold to ignore confidence loss.") .SetDefault(0.7); AddComment(R"DOC( - This operator generate yolov3 loss by given predict result and ground + This operator generates yolov3 loss based on given predict result and ground truth boxes. The output of previous network is in shape [N, C, H, W], while H and W - should be the same, specify the grid size, each grid point predict given - number boxes, this given number is specified by anchors, it should be - half anchors length, which following will be represented as S. In the - second dimention(the channel dimention), C should be S * (class_num + 5), - class_num is the box categoriy number of source dataset(such as coco), - so in the second dimention, stores 4 box location coordinates x, y, w, h - and confidence score of the box and class one-hot key of each anchor box. + should be the same, H and W specify the grid size, each grid point predict + given number boxes, this given number, which following will be represented as S, + is specified by the number of anchors, In the second dimension(the channel + dimension), C should be equal to S * (class_num + 5), class_num is the object + category number of source dataset(such as 80 in coco dataset), so in the + second(channel) dimension, apart from 4 box location coordinates x, y, w, h, + also includes confidence score of the box and class one-hot key of each anchor box. 
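[Editor's note, not part of the patch] As a concrete check of the channel layout described in the rewritten yolov3_loss doc above: with the usual YOLOv3 configuration of S = 3 anchor boxes per grid point and class_num = 80 for the COCO dataset,

$$
C = S \times (class\_num + 5) = 3 \times (80 + 5) = 255
$$

so the prediction tensor for that scale has 255 channels.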
- While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions - correspnd to: + Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box predictions + should be as follows: $$ - b_x = \sigma(t_x) + c_x - b_y = \sigma(t_y) + c_y + b_x = \\sigma(t_x) + c_x + $$ + $$ + b_y = \\sigma(t_y) + c_y + $$ + $$ b_w = p_w e^{t_w} + $$ + $$ b_h = p_h e^{t_h} $$ - While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$ - is specified by anchors. + In the equation above, :math:`c_x, c_y` is the left top corner of current grid + and :math:`p_w, p_h` is specified by anchors. As for confidence score, it is the logistic regression value of IoU between anchor boxes and ground truth boxes, the score of the anchor box which has - the max IoU should be 1, and if the anchor box has IoU bigger then ignore + the max IoU should be 1, and if the anchor box has IoU bigger than ignore thresh, the confidence score loss of this anchor box will be ignored. Therefore, the yolov3 loss consist of three major parts, box location loss, @@ -186,13 +192,13 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { In order to trade off box coordinate losses between big boxes and small boxes, box coordinate losses will be mutiplied by scale weight, which is - calculated as follow. + calculated as follows. $$ weight_{box} = 2.0 - t_w * t_h $$ - Final loss will be represented as follow. + Final loss will be represented as follows. $$ loss = (loss_{xy} + loss_{wh}) * weight_{box} diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h index 3f110024b285d41ccfe305e35c8efca5ed5ee0fe..ca998826dd0118ab4b1ecc23bed8ef882f1bcc92 100644 --- a/paddle/fluid/operators/lstm_op.h +++ b/paddle/fluid/operators/lstm_op.h @@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel { lstm_value.output_value = out_t.data(); lstm_value.state_value = cell_t.data(); lstm_value.state_active_value = cell_pre_act_t.data(); + T cell_clip = 0.0; math::LstmUnitFunctor::compute( - device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, - cell_act, cand_act); + device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip, + gate_act, cell_act, cand_act); lstm_value.prev_state_value = lstm_value.state_value; } @@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel { lstm_value.output_value = nullptr; lstm_grad.state_active_grad = nullptr; int cur_batch_size = bend - bstart; + T cell_clip = 0.0; math::LstmUnitGradFunctor::compute( device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, - gate_act, cell_act, cand_act); + cell_clip, gate_act, cell_act, cand_act); if (n > 0) { int pre_h_start = static_cast(batch_starts[n - 1]); diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index 7a62bc9f828e4d3485628747cdf52c60c5354144..2728aa8a4ee21a9e1fe3deddcdba4c35a6aba7bc 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("C0"), "Input(C0) of LSTMP operator should not be null after " "Input(H0) provided."); - auto h_dims = ctx->GetInputDim("H0"); - auto c_dims = ctx->GetInputDim("C0"); - PADDLE_ENFORCE(h_dims == c_dims, - "The dimension of Input(H0) and Input(C0) " - "should be the same."); - ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]}); } auto b_dims = ctx->GetInputDim("Bias"); @@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { 
"This LoDTensor is obtained in the forward and used in the " "backward.") .AsIntermediate(); - AddOutput("OrderedP0", - "(Tensor) the projection of the initial hidden state " - "H0. This is a tensor with shape (N x P), where N is the " - "batch size and P is the hidden size.") - .AsIntermediate(); AddAttr("use_peepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") @@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: False) " "whether to compute reversed LSTMP.") .SetDefault(false); + AddAttr("cell_clip", + "(float, defalut: 0.0) " + "Clip for Tensor for cell state tensor when clip value is " + "greater than 0.0") + .SetDefault(0.0); + AddAttr("proj_clip", + "(float, defalut: 0.0) " + "Clip for Tensor for projection tensor when clip value is " + "greater than 0.0") + .SetDefault(0.0); AddAttr( "gate_activation", "(string, default: sigmoid)" diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index 1f11e57dcb721012c7b8e50d7e138355685053da..c7d6e4205f8862526904e4fa767a2f4c4a2d8481 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" @@ -21,17 +22,50 @@ limitations under the License. */ #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; +using platform::Transform; template using EigenMatrix = framework::EigenMatrix; +template +class _ClipFunctor { + public: + explicit _ClipFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T& x) const { + if (x < min_) + return min_; + else if (x > max_) + return max_; + else + return x; + } + + private: + T min_; + T max_; +}; + +template +class _ClipGradFunctor { + public: + explicit _ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T& x, const T& y) const { + return (y > min_ && y < max_) ? x : 0; + } + + private: + T min_; + T max_; +}; + template inline void ReorderInitState(const DeviceContext& ctx, const framework::Tensor& src, @@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel { auto* bias = ctx.Input("Bias"); auto* hidden_t0 = ctx.Input("H0"); - auto* ordered_proj0 = ctx.Output("OrderedP0"); auto* cell_t0 = ctx.Input("C0"); + auto proj_clip = static_cast(ctx.Attr("proj_clip")); + auto cell_clip = static_cast(ctx.Attr("cell_clip")); + auto* batch_gate = ctx.Output("BatchGate"); batch_gate->mutable_data(ctx.GetPlace()); auto* proj_out = ctx.Output("Projection"); @@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel { } lstmp_value.prev_state_value = nullptr; Tensor ordered_c0; + Tensor ordered_h0; framework::Vector order(batch_gate->lod()[2]); @@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel { // Since the batch computing for LSTMP reorders the input sequence // according to their length. The initialized hidden state also needs // to reorder. 
- - Tensor ordered_h0; - ordered_proj0->mutable_data(ctx.GetPlace()); ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0, true); - blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast(1.0), - ordered_proj0, static_cast(0.0)); - if (proj_act != math::detail::ActivationType::kIdentity) { - auto proj0_dev = EigenMatrix::From(*ordered_proj0); - ActCompute(cell_act, place, proj0_dev, proj0_dev); - } - blas.MatMul(*ordered_proj0, false, *weight, false, static_cast(1.0), + blas.MatMul(ordered_h0, false, *weight, false, static_cast(1.0), &gate_t, static_cast(1.0)); } @@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel { lstmp_value.state_value = cell_t.data(); lstmp_value.state_active_value = cell_pre_act_t.data(); math::LstmUnitFunctor::compute( - device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act, - cell_act, cand_act); + device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip, + gate_act, cell_act, cand_act); lstmp_value.prev_state_value = lstmp_value.state_value; blas.MatMul(hidden_t, false, *proj_weight, false, static_cast(1.0), &proj_t, static_cast(0.0)); @@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel { auto proj_t_dev = EigenMatrix::From(proj_t); ActCompute(cell_act, place, proj_t_dev, proj_t_dev); } + if (proj_clip && proj_clip > 0.0) { + T* x_data = proj_t.data(); + int64_t numel = proj_t.numel(); + Transform trans; + trans(ctx.template device_context(), x_data, + x_data + numel, x_data, + _ClipFunctor(-1.0 * proj_clip, proj_clip)); + } } math::Batch2LoDTensorFunctor to_seq; @@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel { auto* proj_out = ctx.Input("Projection"); auto* cell_out = ctx.Input("Cell"); + auto proj_clip = static_cast(ctx.Attr("proj_clip")); + auto cell_clip = static_cast(ctx.Attr("cell_clip")); + auto* batch_gate = ctx.Input("BatchGate"); auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); auto* batch_hidden = ctx.Input("BatchHidden"); @@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel { auto* bias_g = ctx.Output(framework::GradVarName("Bias")); auto* h0 = ctx.Input("H0"); - auto* ordered_proj0 = ctx.Input("OrderedP0"); auto* c0 = ctx.Input("C0"); auto* h0_g = ctx.Output(framework::GradVarName("H0")); @@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel { Tensor cur_proj = batch_proj.Slice(bstart, bend); Tensor proj_g = batch_proj_g.Slice(bstart, bend); + + if (proj_clip && proj_clip > 0.0) { + T* dx_data = proj_g.data(); + T* x_data = cur_proj.data(); + int64_t numel = proj_g.numel(); + Transform trans; + trans(ctx.template device_context(), dx_data, + dx_data + numel, x_data, dx_data, + _ClipGradFunctor(-1.0 * proj_clip, proj_clip)); + } + if (proj_act != math::detail::ActivationType::kIdentity) { auto cur_proj_dev = EigenMatrix::From(cur_proj); auto proj_g_dev = EigenMatrix::From(proj_g); @@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel { math::LstmUnitGradFunctor::compute( device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size, - gate_act, cell_act, cand_act); + cell_clip, gate_act, cell_act, cand_act); if (n > 0) { int pre_h_start = static_cast(batch_starts[n - 1]); @@ -431,31 +480,14 @@ class LSTMPGradKernel : public framework::OpKernel { ReorderInitState(device_ctx, *h0, order, &ordered_h0, true); if (weight_g) { - blas.MatMul(*ordered_proj0, true, gate_g, false, - static_cast(1.0), weight_g, static_cast(1.0)); + blas.MatMul(ordered_h0, true, gate_g, false, static_cast(1.0), + weight_g, 
static_cast(1.0)); } } if (h0 && (h0_g || proj_weight_g)) { ordered_h0_g.mutable_data(h0_g->dims(), ctx.GetPlace()); - Tensor proj0_g; - proj0_g.Resize({in_dims[0], proj_weight->dims()[1]}); - proj0_g.mutable_data(ctx.GetPlace()); blas.MatMul(gate_g, false, *weight, true, static_cast(1.0), - &proj0_g, static_cast(0.0)); - if (proj_act != math::detail::ActivationType::kIdentity) { - auto proj0_dev = EigenMatrix::From(*ordered_proj0); - auto proj0_g_dev = EigenMatrix::From(proj0_g); - ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev, - proj0_g_dev); - } - if (h0_g) { - blas.MatMul(proj0_g, false, *proj_weight, true, static_cast(1.0), - &ordered_h0_g, static_cast(0.0)); - } - if (proj_weight_g) { - blas.MatMul(ordered_h0, true, proj0_g, false, static_cast(1.0), - proj_weight_g, static_cast(1.0)); - } + &ordered_h0_g, static_cast(0.0)); } } } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 4b6eef18d8b967af5f3a5df0dee750620e7e412a..d4837696241b8c4e3cca4f2afe872c6be559853c 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -39,6 +39,7 @@ math_library(cross_entropy) math_library(cos_sim_functor) math_library(depthwise_conv DEPS cub) math_library(im2col) +math_library(sample_prob) math_library(sampler) math_library(gru_compute DEPS activation_functions math_function) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index ce8109f64d62b0d412419107881952f1b4ffc75e..f67f57827bc03e134bf87edd5bf033adb5098916 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -184,9 +184,6 @@ class Blas { template void VINV(int n, const T* a, T* y) const; - template - void VMERF(int n, const T* a, T* y, int64_t mode) const; - private: const DeviceContext& context_; }; @@ -293,11 +290,6 @@ class BlasT : private Blas { Base()->template VINV(args...); } - template - void VMERF(ARGS... args) const { - Base()->template VMERF(args...); - } - private: const Blas* Base() const { return static_cast*>(this); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index ba995dabecbfab8c4952bb7efeaa381f8078821a..972366bc093f4b7f0a090cf31213f75ccd89fd82 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -123,11 +123,6 @@ struct CBlas { static void VINV(ARGS... args) { platform::dynload::vsInv(args...); } - - template - static void VMERF(ARGS... args) { - platform::dynload::vmsErf(args...); - } }; template <> @@ -228,11 +223,6 @@ struct CBlas { static void VINV(ARGS... args) { platform::dynload::vdInv(args...); } - - template - static void VMERF(ARGS... 
args) { - platform::dynload::vmdErf(args...); - } }; #else @@ -635,19 +625,6 @@ void Blas::VINV(int n, const T *a, T *y) const { #endif } -template <> -template -void Blas::VMERF(int n, const T *a, T *y, - int64_t mode) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMERF(n, a, y, mode); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::erf(a[i]); - } -#endif -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h index 2e3779ff0845294e71f27801049c010e0a585e6b..ad79c58063a8a12c703979fe32a8e671a5ade857 100644 --- a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h @@ -32,7 +32,8 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frame_size, ActivationType active_node, + int frame_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { T r_value_in; @@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); value_in[i] = r_value_in; value_ig[i] = r_value_ig; @@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, - ActivationType active_node, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { T r_value_in; @@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); grad_in[i] = r_grad_in; grad_ig[i] = r_grad_ig; @@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, template void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frame_size, ActivationType active_node, + int frame_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { #ifdef __AVX__ @@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); value_in[i] = r_value_in; value_ig[i] = r_value_ig; @@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, - ActivationType active_node, + T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { #ifdef __AVX__ @@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, - active_node, 
active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); grad_in[i] = r_grad_in; grad_ig[i] = r_grad_ig; @@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, template void cpu_lstm_forward(Op op, LstmMetaValue value, int frame_size, - ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + T cell_clip, ActivationType active_node, + ActivationType active_gate, ActivationType active_state) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_forward_one_sequence(op, value, frame_size, active_node, - active_gate, active_state); + avx_lstm_forward_one_sequence(op, value, frame_size, cell_clip, + active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frame_size, active_node, - active_gate, active_state); + naive_lstm_forward_one_sequence(op, value, frame_size, cell_clip, + active_node, active_gate, active_state); } } template void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, ActivationType active_node, + int frame_size, T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_backward_one_sequence(op, value, grad, frame_size, active_node, - active_gate, active_state); + avx_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, + active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frame_size, + naive_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, active_node, active_gate, active_state); } } diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h index 2aecb69237fdf344ebc0bfe72d9c7c147f06358d..e0ca9e7f5b2f4a8bb837768d645b5103aa3e6760 100644 --- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h @@ -31,7 +31,8 @@ namespace detail { */ template __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, - int batch_size, ActivationType active_node, + int batch_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, - active_node, active_gate, active_state); + &cell_clip, active_node, active_gate, active_state); value.gate_value[frame_idx] = r_value_in; value.gate_value[frame_idx + frame_size] = r_value_ig; @@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frame_size, template __global__ void KeLstmBackward(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, - int batch_size, ActivationType active_node, + int batch_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, ActivationType active_state) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, - &r_checkO, 
&r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node, - active_gate, active_state); + &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &cell_clip, + active_node, active_gate, active_state); grad.gate_grad[frame_idx] = r_grad_in; grad.gate_grad[frame_idx + frame_size] = r_grad_ig; @@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, template void gpu_lstm_forward(const platform::DeviceContext& context, Op op, LstmMetaValue value, int frame_size, int batch_size, - ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + T cell_clip, ActivationType active_node, + ActivationType active_gate, ActivationType active_state) { dim3 threads; dim3 grid; if (batch_size == 1) { @@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batch_size == 1) { KeLstmForward<<>>( - op, value, frame_size, batch_size, active_node, active_gate, + op, value, frame_size, batch_size, cell_clip, active_node, active_gate, active_state); } else { KeLstmForward<<>>( - op, value, frame_size, batch_size, active_node, active_gate, + op, value, frame_size, batch_size, cell_clip, active_node, active_gate, active_state); } } @@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, template void gpu_lstm_backward(const platform::DeviceContext& context, Op op, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { dim3 threads; @@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, if (batch_size == 1) { KeLstmBackward<<>>( - op, value, grad, frame_size, batch_size, active_node, active_gate, - active_state); + op, value, grad, frame_size, batch_size, cell_clip, active_node, + active_gate, active_state); } else { KeLstmBackward<<>>( - op, value, grad, frame_size, batch_size, active_node, active_gate, - active_state); + op, value, grad, frame_size, batch_size, cell_clip, active_node, + active_gate, active_state); } } diff --git a/paddle/fluid/operators/math/detail/lstm_kernel.h b/paddle/fluid/operators/math/detail/lstm_kernel.h index cbe73d62938d7c4c03a2c8731665260624417fd7..8149686c97a030b91e0c4de708b9abf07f83203d 100644 --- a/paddle/fluid/operators/math/detail/lstm_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_kernel.h @@ -29,7 +29,7 @@ class lstm { public: HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, T *prev_state, T *state, T *state_atv, T *output, - T *checkI, T *checkF, T *checkO, + T *checkI, T *checkF, T *checkO, T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -37,6 +37,15 @@ class lstm { *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate); *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate); *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg); + + if (*cell_clip > 0.0) { + if (*state < -1.0 * (*cell_clip)) { + *state = -1.0 * (*cell_clip); + } + if (*state > *cell_clip) { + *state = *cell_clip; + } + } *value_og = activation(*value_og + (*state) * (*checkO), active_gate); *state_atv = activation(*state, active_state); *output = (*value_og) * (*state_atv); @@ -52,7 +61,7 @@ class lstm { __m256 *value_fg, __m256 *value_og, __m256 *prev_state, __m256 *state, __m256 *state_atv, __m256 *output, __m256 *checkI, - __m256 *checkF, __m256 *checkO, + 
__m256 *checkF, __m256 *checkO, T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { @@ -65,6 +74,13 @@ class lstm { active_gate); *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig), _mm256_mul_ps(*prev_state, *value_fg)); + + if (*cell_clip > 0.0f) { + __m256 min = _mm256_set1_ps(0.0f - *cell_clip); + __m256 max = _mm256_set1_ps(*cell_clip); + *state = _mm256_min_ps(max, *state); + *state = _mm256_max_ps(min, *state); + } *value_og = activation( _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate); *state_atv = activation(*state, active_state); @@ -86,15 +102,26 @@ class lstm { T *prev_state, T *prev_state_grad, T *state, T *state_grad, T *state_atv, T *output_grad, T *checkI, T *checkF, T *checkO, T *checkIGrad, - T *checkFGrad, T *checkOGrad, + T *checkFGrad, T *checkOGrad, T *cell_clip, ActivationType active_node, ActivationType active_gate, ActivationType active_state) { *grad_og = activation((*output_grad) * (*state_atv), *value_og, active_gate); - *state_grad += - activation((*output_grad) * (*value_og), *state_atv, active_state) + - (*grad_og) * (*checkO); + if (*cell_clip > 0.0f) { + if (*state >= (*cell_clip) || *state <= (0.0f - (*cell_clip))) { + *state_grad = 0.0f; + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); + } + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); + } + *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node); *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate); *grad_fg = @@ -117,15 +144,24 @@ class lstm { __m256 *prev_state, __m256 *prev_state_grad, __m256 *state, __m256 *state_grad, __m256 *state_atv, __m256 *output_grad, __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad, - __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node, - ActivationType active_gate, ActivationType active_state) { + __m256 *checkFGrad, __m256 *checkOGrad, T *cell_clip, + ActivationType active_node, ActivationType active_gate, + ActivationType active_state) { *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og, active_gate); - *state_grad = - _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), - *state_atv, active_state), - *state_grad); - *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); + if (*cell_clip > 0.0f) { + T *state_ = reinterpret_cast(state); + if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) { + *state_grad = _mm256_set1_ps(0.0f); + } else { + *state_grad = + _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), + *state_atv, active_state), + *state_grad); + *state_grad = + _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); + } + } *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in, active_node); *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig, diff --git a/paddle/fluid/operators/math/lstm_compute.cc b/paddle/fluid/operators/math/lstm_compute.cc index b6882b4fd8e6db8592a282410888d5625bae742a..94bbcbb50670d9f0b11b77cf6a54a99c227521bf 100644 --- a/paddle/fluid/operators/math/lstm_compute.cc +++ b/paddle/fluid/operators/math/lstm_compute.cc @@ -24,12 +24,12 @@ template struct LstmUnitFunctor { static void compute(const platform::CPUDeviceContext& context, LstmMetaValue value, int frame_size, int batch_size, - const detail::ActivationType& 
gate_act, + T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, - cand_act, gate_act, cell_act); + cell_clip, cand_act, gate_act, cell_act); value.gate_value += frame_size * 4; value.state_value += frame_size; value.state_active_value += frame_size; @@ -45,13 +45,14 @@ template struct LstmUnitGradFunctor { static void compute(const platform::CPUDeviceContext& context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { for (int b = 0; b < batch_size; b++) { detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, - frame_size, cand_act, gate_act, cell_act); + frame_size, cell_clip, cand_act, gate_act, + cell_act); value.gate_value += frame_size * 4; value.state_value += frame_size; diff --git a/paddle/fluid/operators/math/lstm_compute.cu b/paddle/fluid/operators/math/lstm_compute.cu index 1233000083d6efc31fcbc527e8e9efb83224b4e3..e7445d3d40ae92ff66e7d33a38bfdebfc8455f0a 100644 --- a/paddle/fluid/operators/math/lstm_compute.cu +++ b/paddle/fluid/operators/math/lstm_compute.cu @@ -24,12 +24,12 @@ template struct LstmUnitFunctor { static void compute(const platform::CUDADeviceContext& context, LstmMetaValue value, int frame_size, int batch_size, - const detail::ActivationType& gate_act, + T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { detail::gpu_lstm_forward(context, detail::forward::lstm(), value, - frame_size, batch_size, cand_act, gate_act, - cell_act); + frame_size, batch_size, cell_clip, cand_act, + gate_act, cell_act); } }; @@ -37,13 +37,13 @@ template struct LstmUnitGradFunctor { static void compute(const platform::CUDADeviceContext& context, LstmMetaValue value, LstmMetaGrad grad, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, const detail::ActivationType& cand_act) { detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, - frame_size, batch_size, cand_act, gate_act, - cell_act); + frame_size, batch_size, cell_clip, cand_act, + gate_act, cell_act); } }; diff --git a/paddle/fluid/operators/math/lstm_compute.h b/paddle/fluid/operators/math/lstm_compute.h index ca2f78e6f318ce39bd2272bbce20f6a6f98fe430..80af5639387aaf6a983365e13c3478353c27a617 100644 --- a/paddle/fluid/operators/math/lstm_compute.h +++ b/paddle/fluid/operators/math/lstm_compute.h @@ -50,7 +50,7 @@ template class LstmUnitFunctor { public: static void compute(const DeviceContext &context, LstmMetaValue value, - int frame_size, int batch_size, + int frame_size, int batch_size, T cell_clip, const detail::ActivationType &gate_act, const detail::ActivationType &cell_act, const detail::ActivationType &cand_act); @@ -61,7 +61,7 @@ class LstmUnitGradFunctor { public: static void compute(const DeviceContext &context, LstmMetaValue value, LstmMetaGrad grad, int frame_size, int batch_size, - const detail::ActivationType &gate_act, + T cell_clip, const detail::ActivationType &gate_act, const detail::ActivationType &cell_act, const detail::ActivationType &cand_act); }; diff --git a/paddle/fluid/operators/math/sample_prob.cc 
b/paddle/fluid/operators/math/sample_prob.cc new file mode 100644 index 0000000000000000000000000000000000000000..99aa318453eae161807353198a78e11085cd6237 --- /dev/null +++ b/paddle/fluid/operators/math/sample_prob.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/sample_prob.h" + +namespace paddle { +namespace operators { +namespace math { + +template class SampleWithProb; +template class SampleWithProb; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.cu b/paddle/fluid/operators/math/sample_prob.cu new file mode 100644 index 0000000000000000000000000000000000000000..8f9391591560cc3f76ac67f43121c4b1cff90e12 --- /dev/null +++ b/paddle/fluid/operators/math/sample_prob.cu @@ -0,0 +1,161 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sample_prob.h" +#include "paddle/fluid/operators/math/sampler.h" + +namespace paddle { +namespace operators { +namespace math { + +using Tensor = framework::Tensor; + +template +__device__ T gpu_adjust_prob(const T prob, const int num_samples, + const int num_tries) { + if (num_samples == num_tries) { + return prob * num_samples; + } else { + return -expm1(num_tries * log1p(-prob)); + } +} + +class GPULogUniformSampler { + public: + __device__ int64_t Sample(float random, const int range, + const float log_range) const; + __device__ float Probability(int64_t value, const float log_range) const; +}; + +__device__ int64_t GPULogUniformSampler::Sample(float random, const int range, + const float log_range) const { + // Got Log Uniform distribution from uniform distribution by + // inverse_transform_sampling method + const int64_t value = static_cast(exp(random * log_range)) - 1; + // Mathematically, value should be <= range_, but might not be due to some + // floating point roundoff, so we mod by range_. 
+ return value % range; +} + +__device__ float GPULogUniformSampler::Probability( + int64_t value, const float log_range) const { + // Given f(x) = 1/[(x+1) * log_range_] + // The value's probability is integral of f(x) from value to (value + 1) + return (log((value + 2.0) / (value + 1.0))) / log_range; +} + +template +__global__ void SamplingCondidate( + const size_t n, const int num_tries, const int range, const float log_range, + const int num_true, const std::size_t num_samples, + const int64_t* label_data, int64_t* samples_data, T* probabilities_data) { + const int num_sampled_classes = num_true + num_samples; + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = 0; + GPULogUniformSampler sampler; + + for (; idx < n; idx += blockDim.x * gridDim.x) { + int col_idx = idx % num_sampled_classes; + int row_idx = idx / num_sampled_classes; + if (col_idx < num_true) { + samples_data[idx] = label_data[row_idx * num_true + col_idx]; + } else { + samples_data[idx] = samples_data[col_idx]; + } + probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range); + probabilities_data[idx] = + gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries); + } +} + +template +int UniqSampler(const Sampler& sampler, const std::size_t num_samples, + int64_t* samples_data) { + // sample num_samles unique samples for an example, note that they are not + // all negative samples + std::unordered_set tmp_samples; + tmp_samples.clear(); + int num_tries = 0; + int j = 0; + while (j < num_samples) { + ++num_tries; + auto v = sampler.Sample(); + auto insert_ok = tmp_samples.insert(v).second; + if (!insert_ok) { + continue; + } + samples_data[j] = v; + ++j; + } + return num_tries; +} + +template +void GPUSampleWithProb::operator()( + const platform::CUDADeviceContext& context, const int seed, + const int dict_size, const bool uniq, const std::size_t num_samples, + const Tensor* L, Tensor* S, Tensor* P) { + // UNDERSTAND: dimension issues + const auto lbl_dim = L->dims(); + const int batch_size = lbl_dim[0]; + const int num_true = lbl_dim[1]; + const int num_sampled_classes = num_true + num_samples; + framework::DDim ret_dim{batch_size, num_sampled_classes}; + + // UNDERSTAND: raw data view + const int64_t* label_data = L->data(); + int64_t* samples_data = S->data(); + T* probabilities_data = P->data(); + + int s_size = num_samples; + framework::DDim s_dim{s_size}; + Tensor s; + int64_t* s_data = s.mutable_data(s_dim, platform::CPUPlace()); + + math::LogUniformSampler sampler(dict_size, seed); + + int range = dict_size; + float log_range = log(range + 1); + + int num_tries = UniqSampler(sampler, num_samples, s_data); + VLOG(1) << "num_tries: " << num_tries; + PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data, + sizeof(int64_t) * num_samples, + cudaMemcpyHostToDevice)); + + int threads = 512; + const size_t size = batch_size * num_sampled_classes; + int grid = (batch_size * num_sampled_classes + threads - 1) / threads; + SamplingCondidate<<>>( + size, num_tries, range, log_range, num_true, num_samples, label_data, + samples_data, probabilities_data); +} + +template class GPUSampleWithProb; +template class GPUSampleWithProb; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.h b/paddle/fluid/operators/math/sample_prob.h new file mode 100644 index 0000000000000000000000000000000000000000..e5a6d84cb2b0527c606e62a19ef02d669945ecb1 --- /dev/null +++ b/paddle/fluid/operators/math/sample_prob.h @@ -0,0 +1,118 @@ 
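[Editor's note, not part of the patch] The header added below implements the same log-uniform sampling math as the CUDA file above. A hedged NumPy sketch of that math, written for illustration rather than taken from the Paddle sources:

```python
import numpy as np

def log_uniform_sample(rng, dict_size):
    """Inverse-transform sampling from P(v) proportional to 1/(v+1) over [0, dict_size)."""
    log_range = np.log(dict_size + 1)
    value = int(np.exp(rng.random() * log_range)) - 1
    return value % dict_size  # guard against floating-point roundoff

def log_uniform_probability(value, dict_size):
    """Integral of 1/[(x+1)*log_range] from value to value+1."""
    return np.log((value + 2.0) / (value + 1.0)) / np.log(dict_size + 1)

def adjust_prob(prob, num_samples, num_tries):
    """Adjust Q(y|x) for unique (without-replacement) sampling."""
    if num_samples == num_tries:
        return prob * num_samples
    # probability the class is drawn at least once in num_tries tries
    return -np.expm1(num_tries * np.log1p(-prob))

rng = np.random.default_rng(0)
v = log_uniform_sample(rng, dict_size=10000)
p = log_uniform_probability(v, dict_size=10000)
print(v, adjust_prob(p, num_samples=25, num_tries=40))
```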
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/sampler.h" + +namespace paddle { +namespace operators { +namespace math { + +using Tensor = framework::Tensor; + +/* UNDERSTAND: utility function to adjust probability for unique sampling, +return whatever as it is if not using unique samping */ +template +static T adjust_prob(const T prob, const int num_samples, const int num_tries) { + if (num_samples == num_tries) { + return prob * num_samples; + } else { + return -expm1(num_tries * log1p(-prob)); + } +} + +template +class SampleWithProb { + public: + void operator()(const DeviceContext& context, const Sampler& sampler, + const std::size_t num_samples, const Tensor* L, Tensor* S, + Tensor* P) { + // UNDERSTAND: dimension issues + const auto lbl_dim = L->dims(); + const int batch_size = lbl_dim[0]; + const int num_true = lbl_dim[1]; + const int num_sampled_classes = num_true + num_samples; + framework::DDim ret_dim{batch_size, num_sampled_classes}; + + // UNDERSTAND: raw data view + const int64_t* label_data = L->data(); + int64_t* samples_data = + S->mutable_data(ret_dim, context.GetPlace()); + T* probabilities_data = P->mutable_data(ret_dim, context.GetPlace()); + + // temp sets for unique sampling + std::unordered_set tmp_samples; + int j = 0; // column index + // add true labels, not that efficient + while (j < num_true) { + for (int i = 0; i < batch_size; ++i) { + auto samples_index = i * num_sampled_classes + j; + auto v = label_data[i * num_true + j]; + samples_data[samples_index] = v; + probabilities_data[samples_index] = sampler.Probability(v); + } + ++j; + } + + // sample num_samles unique samples for an example, note that they are not + // all negative samples + tmp_samples.clear(); + int num_tries = 0; + while (j < num_sampled_classes) { + ++num_tries; + auto v = sampler.Sample(); + auto insert_ok = tmp_samples.insert(v).second; + if (!insert_ok) { + continue; + } + auto p = sampler.Probability(v); + for (int i = 0; i < batch_size; ++i) { + auto samples_index = i * num_sampled_classes + j; + samples_data[samples_index] = v; + probabilities_data[samples_index] = p; + } + ++j; + } + + // compute Q(y|x), because of unique sampling, probabilities need to be + // adjusted + for (int k = 0; k < num_sampled_classes; ++k) { + for (int i = 0; i < batch_size; ++i) { + auto samples_index = i * num_sampled_classes + k; + probabilities_data[samples_index] = adjust_prob( + probabilities_data[samples_index], num_samples, num_tries); + } + } + } +}; + +#ifdef PADDLE_WITH_CUDA +template +class GPUSampleWithProb { + public: + void operator()(const platform::CUDADeviceContext& context, const int seed, + const int dict_size, const bool uniq, + const std::size_t num_samples, const Tensor* L, Tensor* S, + Tensor* P); +}; +#endif +} // namespace 
math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index 223adcaa6b36e85ea54004c850ba6cfd142eac37..5b7505f3c4acdef94fead04efd00b47825274117 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -225,7 +225,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx, std::static_pointer_cast(dev_ctx.GetBlob(key_src_mem)); PADDLE_ENFORCE(src_memory != nullptr, "Fail to find src_memory in device context"); - src_memory->set_data_handle(*p_src_data.get()); + src_memory->set_data_handle(*p_src_data); std::shared_ptr diff_src_memory; diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index f4bad7b712b2b078ed68f0a3d0e751d9ae2d6191..38a65b50bd22354bea54819e8e71015202e96e9f 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -198,7 +198,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { } // push primitive to stream and wait until it's executed - std::vector pipeline{*(pool_p.get())}; + std::vector pipeline{*pool_p}; stream(stream::kind::eager).submit(pipeline).wait(); output->set_layout(DataLayout::kMKLDNN); @@ -367,8 +367,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory); pool_bwd_p = std::make_shared( - pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory, - *(diff_src_memory)); + pool_bwd_pd, *diff_dst_memory, *workspace_memory, *diff_src_memory); dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p); } else { @@ -404,7 +403,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { if (is_diff_dst_reordered) { pipeline.push_back(reorder_diff_dst); } - pipeline.push_back(*(pool_bwd_p.get())); + pipeline.push_back(*pool_bwd_p); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); in_x_grad->set_layout(DataLayout::kMKLDNN); diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index d2b149535426d097fea4b8fffa9efe82bd6edc64..dc1176f0848b93dd6872f676c3a71dab4f3455fd 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -66,8 +66,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { "Fail to find softmax primitive in device context"); if (softmax_p == nullptr) { softmax_p = std::make_shared( - *(softmax_pd_.get()), - *(static_cast(src_memory_p.get())), + *softmax_pd_, *(static_cast(src_memory_p.get())), *(static_cast(dst_memory_p.get()))); dev_ctx_.SetBlob(prim_key, softmax_p); } else { @@ -88,8 +87,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { "Fail to find softmax backward primitive in device context"); if (softmax_bwd_p == nullptr) { softmax_bwd_p = std::make_shared( - *softmax_bwd_pd_, *(dst_memory_p.get()), *(diff_dst_memory_p.get()), - *(diff_src_memory_p.get())); + *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p, + *diff_src_memory_p); dev_ctx_.SetBlob(prim_key, softmax_bwd_p); } else { is_reusing_ = true; diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index c39f94637a1abb5bfce9a5428419282f2b870c91..fe4131df2c77ed28cd36f23002d000dac3e8a129 100644 --- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ 
b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -160,7 +160,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { auto get_selected_row = [&](size_t i) -> const SelectedRows& { if (i == 0 && in0) { - return *in0.get(); + return *in0; } else { return in_vars[i]->Get(); } diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index fc3636e0b24765f681d3260b07fe854309774a40..7e1df3b9efec64c3189d2cd80e761994cc061b45 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -262,28 +262,37 @@ Example: For exclusive = false: $$ hstart = i * strides[0] - paddings[0] + $$ + $$ hend = hstart + ksize[0] + $$ + $$ wstart = j * strides[1] - paddings[1] + $$ + $$ wend = wstart + ksize[1] + $$ + $$ Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{ksize[0] * ksize[1]} $$ + For exclusive = true: $$ hstart = max(0, i * strides[0] - paddings[0]) + $$ + $$ hend = min(H, hstart + ksize[0]) + $$ + $$ wstart = max(0, j * strides[1] - paddings[1]) + $$ + $$ wend = min(W, wstart + ksize[1]) + $$ + $$ Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)} $$ - For adaptive = true: - $$ - hstart = floor(i * H_{in} / H_{out}) - hend = ceil((i + 1) * H_{in} / H_{out}) - wstart = floor(j * W_{in} / W_{out}) - wend = ceil((j + 1) * W_{in} / W_{out}) - Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)} - $$ )DOC"); } @@ -392,48 +401,68 @@ Example: Output: Out shape: $(N, C, D_{out}, H_{out}, W_{out})$ For ceil_mode = false: - $$ - D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ - H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 \\ - W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1 - $$ + $$ + D_{out} = \\frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 + $$ + $$ + H_{out} = \\frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 + $$ + $$ + W_{out} = \\frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1 + $$ For ceil_mode = true: - $$ - D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0] + strides[0] -1)}{strides[0]} + 1 \\ - H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1] + strides[1] -1)}{strides[1]} + 1 \\ - W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2] + strides[2] -1)}{strides[2]} + 1 - $$ + $$ + D_{out} = \\frac{(D_{in} - ksize[0] + 2 * paddings[0] + strides[0] -1)}{strides[0]} + 1 + $$ + $$ + H_{out} = \\frac{(H_{in} - ksize[1] + 2 * paddings[1] + strides[1] -1)}{strides[1]} + 1 + $$ + $$ + W_{out} = \\frac{(W_{in} - ksize[2] + 2 * paddings[2] + strides[2] -1)}{strides[2]} + 1 + $$ + For exclusive = false: - $$ - dstart = i * strides[0] - paddings[0] - dend = dstart + ksize[0] - hstart = j * strides[1] - paddings[1] - hend = hstart + ksize[1] - wstart = k * strides[2] - paddings[2] - wend = wstart + ksize[2] - Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{ksize[0] * ksize[1] * ksize[2]} - $$ + $$ + dstart = i * strides[0] - paddings[0] + $$ + $$ + dend = dstart + ksize[0] + $$ + $$ + hstart = j * strides[1] - paddings[1] + $$ + $$ + hend = hstart + ksize[1] + $$ + $$ + wstart = k * strides[2] - paddings[2] + $$ + $$ + wend = wstart + ksize[2] + $$ + $$ + Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{ksize[0] * ksize[1] * ksize[2]} + $$ + For exclusive = true: - $$ - dstart = max(0, i * strides[0] - paddings[0]) - dend = min(D, dstart + ksize[0]) - hstart = max(0, j * strides[1] - paddings[1]) - hend = min(H, hstart + ksize[1]) - wstart = max(0, k * strides[2] - paddings[2]) - wend = min(W, wstart + ksize[2]) - Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)} - $$ - - For adaptive = true: - $$ - dstart = floor(i * D_{in} / D_{out}) - dend = ceil((i + 1) * D_{in} / D_{out}) - hstart = floor(j * H_{in} / H_{out}) - hend = ceil((j + 1) * H_{in} / H_{out}) - wstart = floor(k * W_{in} / W_{out}) - wend = ceil((k + 1) * W_{in} / W_{out}) - Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)} - $$ + $$ + dstart = max(0, i * strides[0] - paddings[0]) + $$ + $$ + dend = min(D, dstart + ksize[0]) + $$ + $$ + hstart = max(0, j * strides[1] - paddings[1]) + $$ + $$ + hend = min(H, hstart + ksize[1]) + $$ + $$ + wstart = max(0, k * strides[2] - paddings[2]) + $$ + $$ + wend = min(W, wstart + ksize[2]) + $$ + $$ + Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)} + $$ )DOC"); } diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7f7fb26b17c77e6fe87646d3cac20c02c49b52c --- /dev/null +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -0,0 +1,225 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sample_logits_op.h" +#include "paddle/fluid/operators/math/sample_prob.h" + +namespace paddle { +namespace operators { + +class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Logits", + "(Tensor, default: Tensor), The unscaled log probabilities " + "which is a 2-D tensor with shape [N x K]. N is the batch_size, " + "and K is the class number."); + AddInput("Labels", + "(Tensor) The ground truth which is a 2-D tensor. Labels is a " + "Tensor with shape [N x NT], where NT is the number of " + "true labels for each example."); + AddInput("CustomizedSamples", + "(Tensor, default: Tensor), A 2-D tensor with shape [N, " + "NT + S]," + " where N is the batch size, NT is the number of true labels " + "and S is the number of negative samples for each example. " + "The first NT elements of each row should be the same with true " + "labels, " + "followed by S custom negative samples. This tensor " + "is only used when use_customized_samples is true.") + .AsDispensable(); + AddInput( + "CustomizedProbabilities", + "(Tensor, default: Tensor), A 2-D tensor with shape [N, NT + S]. " + "The tensor has the same shape as CustomizedSamples, " + "and each element represents the probability of the corresponding " + "element in CustomizedSamples. " + "This " + "tensor is only used when use_customized_samples is true.") + .AsDispensable(); + AddOutput("Samples", + "(Tensor, default: Tensor), A 2-D tensor with shape [N, " + "NT + S]." + "The output value of the sampler, including NT true labels and S " + "negative samples " + "for each example. This will be used in " + "backward calculation.") + .AsIntermediate(); + AddOutput( + "Probabilities", + "(Tensor, default: Tensor), A 2-D tensor with shape [N, NT + S]. " + "The probabilities of sampled positive and negative labels.") + .AsIntermediate(); + AddOutput("SampledLogits", + "(Tensor, default: Tensor), A 2-D tensor with shape " + "[N, NT + S]. The output value of sampled logits, which will be " + "used in backward propagation.") + .AsIntermediate(); + AddOutput( + "SampledLabels", + "(Tensor, default: Tensor), A 2-D tensor. The sampled labels " + "with shape [N, NT]. The tensor contains hard labels as input to " + "the softmax op, that is 0, 1, ..., NT-1, because the first NT elements" + " of Samples are positive labels."); + AddAttr( + "use_customized_samples", + "An indicator of whether to use customized samples with probabilities; if " + "True, " + "the operator will use customized samples and customized probabilities; " + "otherwise, the operator will generate them by itself.") + .SetDefault(false); + AddAttr( + "uniq", + "An indicator of whether to sample non-repetitive negative labels; if True, " + "the operator will sample negative labels without replacement. " + "Otherwise, the operator will sample negative labels with replacement.") + .SetDefault(true); + AddAttr( + "remove_accidental_hits", + "An indicator of whether to remove accidental hits when a sample hits true " + "labels; the removal is implemented by subtracting the corresponding " + "logits by float_max to suppress their softmax to be zero.") + .SetDefault(true); + AddAttr("num_samples", "The number of negative samples."); + AddAttr("seed", "Random seed for generating samples").SetDefault(0); + + AddComment(R"DOC( + """ + Computes sampled output training logits and labels suitable for implementing + sampled softmax.
+ """ + +)DOC"); + } +}; + +class SampleLogitsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Logits"), + "Input(Logits) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Labels"), + "Input(Labels) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Samples"), + "Output(Samples) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Probabilities"), + "Output(Probabilities) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"), + "Output(SampledLogits) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"), + "Output(SampledLabels) should be not null."); + + auto logits_dims = ctx->GetInputDim("Logits"); + auto labels_dims = ctx->GetInputDim("Labels"); + + PADDLE_ENFORCE_EQ( + logits_dims.size(), 2UL, + "The logits of softmax_with_cross_entropy should be a 2-D tensor."); + PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL, + "The labels should be a 2-D tensor."); + + const int num_samples = ctx->Attrs().Get("num_samples"); + const int num_sampled_classes = labels_dims[1] + num_samples; + ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes}); + ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes}); + ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes}); + ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits")); + framework::OpKernelType kt = + framework::OpKernelType(data_type, ctx.device_context()); + return kt; + } +}; + +// UNDERSTAND: InferShape for Grad +class SampleLogitsOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Logits"), + "Input(Logits) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Labels"), + "Input(Labels) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Samples"), + "Input(Samples) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("SampledLogits"), + "Input(SampledLogits) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")), + "Input(SampledLogits@Grad) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), + "Output(Logits@Grad) should be not null."); + + auto logit_dims = ctx->GetInputDim("Logits"); + auto label_dims = ctx->GetInputDim("Labels"); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "The label should be a 2-D tensor."); + PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL, + "The logits should be a 2-D tensor."); + + ctx->SetOutputDim(framework::GradVarName("Logits"), + ctx->GetInputDim("Logits")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar( + ctx.InputVar(framework::GradVarName("SampledLogits"))); + framework::OpKernelType kt = + framework::OpKernelType(data_type, ctx.device_context()); + return kt; + } +}; + +// UNDERSTAND: what's the rule for making a GradMaker TODO +class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; 
+ + protected: + std::unique_ptr Apply() const override { + auto* grad_op = new framework::OpDesc(); + grad_op->SetType("sample_logits_grad"); + grad_op->SetInput("Logits", Input("Logits")); + grad_op->SetInput("Labels", Input("Labels")); + grad_op->SetInput("Samples", Output("Samples")); + grad_op->SetInput("SampledLogits", Output("SampledLogits")); + grad_op->SetInput(framework::GradVarName("SampledLogits"), + OutputGrad("SampledLogits")); + grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(sample_logits, ops::SampleLogitsOp, ops::SampleLogitsOpMaker, + ops::SampleLogitsGradMaker); +REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad); +REGISTER_OP_CPU_KERNEL(sample_logits, ops::SampleLogitsKernel, + ops::SampleLogitsKernel); +REGISTER_OP_CPU_KERNEL(sample_logits_grad, ops::SampleLogitsGradKernel, + ops::SampleLogitsGradKernel); diff --git a/paddle/fluid/operators/sample_logits_op.cu b/paddle/fluid/operators/sample_logits_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..fb49793b730f72d66dc846f233bd95ebdab37c52 --- /dev/null +++ b/paddle/fluid/operators/sample_logits_op.cu @@ -0,0 +1,257 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sample_prob.h" +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/operators/sample_logits_op.h" + +namespace paddle { +namespace operators { + +// UNDERSTAND: something like take_along_axis in numpy. +template +__global__ void GPUTakeAlongD1(size_t size, const int batch_size, + const int array_slice_size, + const int idx_slice_size, const T* p_array, + const int64_t* p_index, T* p_value) { + const auto value_slice_size = idx_slice_size; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + for (; idx < size; idx += step_size) { + int i = idx / idx_slice_size; + auto array_index = p_index[idx]; + p_value[idx] = p_array[i * array_slice_size + array_index]; + } +} + +// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate +// indices, scatter is done in += way. 
+template +__global__ void GPUPutAlongD1(size_t size, const int batch_size, + const int array_slice_size, + const int idx_slice_size, T* p_array, + const int64_t* p_index, const T* p_value) { + const auto value_slice_size = idx_slice_size; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + // size == batch_size + for (; idx < size; idx += step_size) { + int i = idx; + for (int j = 0; j < idx_slice_size; ++j) { + auto array_index = p_index[i * idx_slice_size + j]; + p_array[i * array_slice_size + array_index] += + p_value[i * idx_slice_size + j]; + } + } +} + +// UNDERSTAND: set label as 0,1,...,num_true-1 +template +__global__ void GPUSetLabel(size_t size, const int num_true, int64_t* p_array) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + for (; idx < size; idx += step_size) { + p_array[idx] = idx % num_true; + } +} + +// UNDERSTAND: compute accidentdal hits from samples and minus corresponding +// logits by a float max, here 1e20 +template +__global__ void gpu_compute_remove_accidental_hits(const int size, + const int num_true, + const int idx_slice_size, + const int64_t* p_index, + T* p_value) { + const auto value_slice_size = idx_slice_size; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + for (; idx < size; idx += step_size) { + int i = idx / idx_slice_size; + if (idx % idx_slice_size < num_true) continue; + for (int j = 0; j < num_true; ++j) { + const auto true_idx = i * idx_slice_size + j; + if (p_index[true_idx] == p_index[idx]) { + p_value[idx] -= 1e20; + break; + } + } + } +} + +template +class SampleLogitsCUDAKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext& context) const override { + // get necessary inputs + const Tensor* logits = context.Input("Logits"); + const Tensor* labels = context.Input("Labels"); + VLOG(3) << "Enter SampleLogitsCUDAKernel"; + + // get necessary outputs + Tensor* samples = context.Output("Samples"); + Tensor* probabilities = context.Output("Probabilities"); + Tensor* sampled_logits = context.Output("SampledLogits"); + Tensor* sampled_labels = context.Output("SampledLabels"); + + // shapes + const auto batch_size = logits->dims()[0]; + const auto num_classes = logits->dims()[1]; + const auto labels_dim = labels->dims(); + const auto num_true = labels_dim[1]; + const auto samples_dim = samples->dims(); + + // attrs + const auto num_samples = context.Attr("num_samples"); + const bool use_customized_samples = + context.Attr("use_customized_samples"); + const bool uniq = context.Attr("uniq"); + const bool remove_accidental_hits = + context.Attr("remove_accidental_hits"); + + // device contexts + auto& dev_ctx = context.cuda_device_context(); + + // UNDERSTAND: allocate memories for temporaries + sampled_logits->mutable_data(samples_dim, context.GetPlace()); + math::SetConstant set_zero; + set_zero(dev_ctx, sampled_logits, static_cast(0)); + + auto sampled_labels_data = + sampled_labels->mutable_data(labels_dim, context.GetPlace()); + int threads = 512; + size_t size = batch_size * num_true; + int grid = (size + threads - 1) / threads; + GPUSetLabel< + T><<>>( + size, num_true, sampled_labels_data); + + if (use_customized_samples) { + const Tensor* customized_samples = + context.Input("CustomizedSamples"); + const Tensor* customized_probabilities = + context.Input("CustomizedProbabilities"); + 
samples->ShareDataWith(*customized_samples); + probabilities->ShareDataWith(*customized_probabilities); + } else { + samples->mutable_data(context.GetPlace()); + probabilities->mutable_data(samples_dim, context.GetPlace()); + // UNDERSTAND: sampling + const auto seed = context.Attr("seed"); + auto sampler_with_prob = math::GPUSampleWithProb(); + sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq, + num_samples, labels, samples, probabilities); + } + + // UNDERSTAND: gather sampled logits and remove accidental hits if needed + const auto num_take = samples->dims()[1]; + const auto array_dims = logits->dims(); + const auto idx_dims = samples->dims(); + + const T* p_array = logits->data(); + const int64_t* p_index = samples->data(); + T* p_value = sampled_logits->data(); + + // src slice size + const auto array_slice_size = array_dims[1]; + // index slice size + const auto idx_slice_size = idx_dims[1]; + + size = batch_size * num_take; + grid = (size + threads - 1) / threads; + GPUTakeAlongD1< + T><<>>( + size, batch_size, array_slice_size, idx_slice_size, p_array, p_index, + p_value); + + if (remove_accidental_hits) { + const size_t size = batch_size * (num_true + num_samples); + int grid = (size + threads - 1) / threads; + gpu_compute_remove_accidental_hits< + T><<>>( + size, num_true, idx_slice_size, p_index, p_value); + } + + // subtracted sampled logits with logQ(y|x) + auto probs = EigenMatrix::From(*probabilities); + auto smp_logits = EigenMatrix::From(*sampled_logits); + smp_logits.device(*dev_ctx.eigen_device()) = + (smp_logits - probs.log().unaryExpr(TolerableValue())) + .unaryExpr(TolerableValue()); + } +}; + +template +class SampleLogitsGradCUDAKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext& context) const override { + auto logits_grad = context.Output(framework::GradVarName("Logits")); + const Tensor* samples = context.Input("Samples"); + const Tensor* sampled_logits_grad = + context.Input(framework::GradVarName("SampledLogits")); + logits_grad->mutable_data(context.GetPlace()); + + auto& dev_ctx = context.cuda_device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, logits_grad, static_cast(0)); + + // UNDERSTAND: scatter it back to logit_grad + const auto batch_size = samples->dims()[0]; + const auto num_put = samples->dims()[1]; + const auto array_dims = logits_grad->dims(); + const auto idx_dims = samples->dims(); + + T* p_array = logits_grad->data(); + const int64_t* p_index = samples->data(); + const T* p_value = sampled_logits_grad->data(); + + // src slice size + const auto array_slice_size = array_dims[1]; + // index slice size + const auto idx_slice_size = idx_dims[1]; + + int threads = 128; + const size_t size = batch_size; + int grid = (size + threads - 1) / threads; + + GPUPutAlongD1< + T><<>>( + size, batch_size, array_slice_size, idx_slice_size, p_array, p_index, + p_value); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(sample_logits, ops::SampleLogitsCUDAKernel, + ops::SampleLogitsCUDAKernel); +REGISTER_OP_CUDA_KERNEL(sample_logits_grad, + ops::SampleLogitsGradCUDAKernel, + ops::SampleLogitsGradCUDAKernel); diff --git a/paddle/fluid/operators/sample_logits_op.h b/paddle/fluid/operators/sample_logits_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b55a24863cc09d5f80e07aedbbb5b3d9ac99e69e --- /dev/null +++ b/paddle/fluid/operators/sample_logits_op.h 
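Before the header below, a rough NumPy sketch (an illustration only, not code from this patch) of the forward computation that both the CPU and CUDA kernels above implement: gather the sampled logits along dim 1, mask accidental hits, then subtract log Q(y|x).

```python
import numpy as np

def sample_logits_forward(logits, samples, probabilities, num_true,
                          remove_accidental_hits=True):
    # logits: [N, K]; samples, probabilities: [N, NT + S] with true labels first.
    sampled_logits = np.take_along_axis(logits, samples, axis=1)
    if remove_accidental_hits:
        true_labels = samples[:, :num_true]
        # A negative sample "accidentally hits" when it equals a true label of its row.
        hits = (samples[:, None, :] == true_labels[:, :, None]).any(axis=1)
        hits[:, :num_true] = False
        sampled_logits[hits] -= 1e20
    # Subtract log Q(y|x) as required by sampled softmax.
    return sampled_logits - np.log(probabilities)
```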
@@ -0,0 +1,245 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sample_prob.h" +#include "paddle/fluid/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; + } +}; + +// UNDERSTAND: something like take_along_axis in numpy. +template +static void CPUTakeAlongD1(const platform::DeviceContext& ctx, + const framework::Tensor& array, + const framework::Tensor& index, + framework::Tensor* value) { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); + // UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K) + PADDLE_ENFORCE(index.dims().size() == 2 && array.dims().size() == 2 && + index.dims()[0] == array.dims()[0] && + index.dims() == value->dims()); + + const auto batch_size = index.dims()[0]; + const auto num_take = index.dims()[1]; + const auto array_dims = array.dims(); + const auto idx_dims = index.dims(); + + // UNDERSTAND: no allocations here + const T* p_array = array.data(); + const int64_t* p_index = index.data(); + T* p_value = value->data(); + + // src slice size + const auto array_slice_size = array_dims[1]; + + // index slice size + const auto idx_slice_size = idx_dims[1]; + const auto value_slice_size = idx_slice_size; + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_take; ++j) { + auto array_index = p_index[i * idx_slice_size + j]; + p_value[i * value_slice_size + j] = + p_array[i * array_slice_size + array_index]; + } + } +} + +// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate +// indices, scatter is done in += way. 
+template +static void CPUPutAlongD1(const platform::DeviceContext& ctx, + framework::Tensor* array, + const framework::Tensor& index, + const framework::Tensor& value) { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); + // UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K) + PADDLE_ENFORCE(index.dims().size() == 2 && array->dims().size() == 2 && + index.dims()[0] == array->dims()[0] && + index.dims() == value.dims()); + const auto batch_size = index.dims()[0]; + const auto num_put = index.dims()[1]; + auto array_dims = array->dims(); + auto idx_dims = index.dims(); + + // UNDERSTAND: no allocations here + T* p_array = array->data(); + const int64_t* p_index = index.data(); + const T* p_value = value.data(); + + // slice sizes + const auto array_slice_size = array_dims[1]; + const auto idx_slice_size = idx_dims[1]; + const auto value_slice_size = idx_slice_size; + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_put; ++j) { + auto array_index = p_index[i * idx_slice_size + j]; + p_array[i * array_slice_size + array_index] += + p_value[i * value_slice_size + j]; + } + } +} + +// UNDERSTAND: compute accidentdal hits from samples and minus corresponding +// logits by a float max, here 1e20 +template +static void compute_remove_accidental_hits(const platform::DeviceContext& ctx, + framework::Tensor* sampled_logits, + const framework::Tensor& samples, + const int num_true) { + const auto batch_size = sampled_logits->dims()[0]; + const auto num_sampled_classes = sampled_logits->dims()[1]; + T* sampled_logits_data = sampled_logits->data(); + const auto samples_data = samples.data(); + + std::unordered_set tmp_true_labels; + for (int i = 0; i < batch_size; ++i) { + tmp_true_labels.clear(); + tmp_true_labels.insert(samples_data + i * num_sampled_classes, + samples_data + i * num_sampled_classes + num_true); + for (int j = num_true; j < num_sampled_classes; ++j) { + const auto idx = i * num_sampled_classes + j; + if (tmp_true_labels.find(samples_data[idx]) != tmp_true_labels.end()) + sampled_logits_data[idx] -= 1e20; + } + } +} + +template +class SampleLogitsKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), + "This kernel only runs on CPU."); + VLOG(3) << "Enter SampleLogitsKernel"; + // get necessary inputs + const Tensor* logits = context.Input("Logits"); + const Tensor* labels = context.Input("Labels"); + + // get necessary outputs + Tensor* samples = context.Output("Samples"); + Tensor* probabilities = context.Output("Probabilities"); + Tensor* sampled_logits = context.Output("SampledLogits"); + Tensor* sampled_labels = context.Output("SampledLabels"); + + // shapes + const auto batch_size = logits->dims()[0]; + const auto num_classes = logits->dims()[1]; + const auto labels_dim = labels->dims(); + const auto num_true = labels_dim[1]; + const auto samples_dim = samples->dims(); + + // attrs + const auto num_samples = context.Attr("num_samples"); + const bool use_customized_samples = + context.Attr("use_customized_samples"); + const bool remove_accidental_hits = + context.Attr("remove_accidental_hits"); + + // device contexts + auto& dev_ctx = + context.template device_context(); + + // UNDERSTAND: allocate memories for temporaries + sampled_logits->mutable_data(samples_dim, context.GetPlace()); + auto sampled_labels_data = + sampled_labels->mutable_data(labels_dim, 
context.GetPlace()); + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_true; ++j) { + sampled_labels_data[i * num_true + j] = j; + } + } + + if (use_customized_samples) { + const Tensor* customized_samples = + context.Input("CustomizedSamples"); + const Tensor* customized_probabilities = + context.Input("CustomizedProbabilities"); + samples->ShareDataWith(*customized_samples); + probabilities->ShareDataWith(*customized_probabilities); + } else { + samples->mutable_data(context.GetPlace()); + probabilities->mutable_data(samples_dim, context.GetPlace()); + // UNDERSTAND: sampling + const auto seed = context.Attr("seed"); + auto sampler_with_prob = + math::SampleWithProb(); + sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed), + num_samples, labels, samples, probabilities); + } + + // UNDERSTAND: gather sampled logits and remove accidental hits if needed + CPUTakeAlongD1(dev_ctx, *logits, *samples, sampled_logits); + if (remove_accidental_hits) { + compute_remove_accidental_hits(dev_ctx, sampled_logits, *samples, + num_true); + } + + // subtracted sampled logits with logQ(y|x) + auto probs = EigenMatrix::From(*probabilities); + auto smp_logits = EigenMatrix::From(*sampled_logits); + smp_logits.device(*dev_ctx.eigen_device()) = + (smp_logits - probs.log().unaryExpr(TolerableValue())) + .unaryExpr(TolerableValue()); + } +}; + +template +class SampleLogitsGradKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext& context) const override { + auto logits_grad = context.Output(framework::GradVarName("Logits")); + const Tensor* samples = context.Input("Samples"); + const Tensor* sampled_logits_grad = + context.Input(framework::GradVarName("SampledLogits")); + logits_grad->mutable_data(context.GetPlace()); + + auto& dev_ctx = + context.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, logits_grad, static_cast(0)); + + // UNDERSTAND: scatter it back to logit_grad + CPUPutAlongD1(dev_ctx, logits_grad, *samples, *sampled_logits_grad); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index ed0dbdeb13ce93926c023f9f435776f1a1839933..920b43b2b1990af58b73888bf7a652d57c20563c 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -394,7 +394,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, int tid = platform::get_cur_thread_id(); - std::lock_guard lock(*p_mutex_.get()); + std::lock_guard lock(*p_mutex_); // Find KeyBlob for current thread auto map_it = pMap->find(tid); @@ -427,7 +427,7 @@ std::shared_ptr MKLDNNDeviceContext::GetBlob( int tid = platform::get_cur_thread_id(); - std::lock_guard lock(*p_mutex_.get()); + std::lock_guard lock(*p_mutex_); // Find KeyBlob for current thread firstly auto map_it = pMap->find(tid); diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 52372c25143be251b69d5f41a211ac090bba2063..0179daa55715be9787bc7cc8a693319024d404b7 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -136,7 +136,7 @@ void EnableActivity() { CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER)); CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME)); // We don't track these activities for now. 
- // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET)); + CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET)); // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD)); // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE)); // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT)); @@ -155,7 +155,7 @@ void DisableActivity() { // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT)); CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER)); CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME)); - // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET)); + CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET)); // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_NAME)); // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MARKER)); // CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD)); @@ -212,6 +212,14 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer, memcpy->correlationId, memcpy->bytes); break; } + case CUPTI_ACTIVITY_KIND_MEMSET: { + auto *memset = + reinterpret_cast(record); + tracer->AddKernelRecords("MEMSET", memset->start, memset->end, + memset->deviceId, memset->streamId, + memset->correlationId); + break; + } case CUPTI_ACTIVITY_KIND_DRIVER: { auto *api = reinterpret_cast(record); if (api->start != 0 && api->end != 0) @@ -348,6 +356,8 @@ class DeviceTracerImpl : public DeviceTracer { const std::vector cbids { CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020, + CUPTI_RUNTIME_TRACE_CBID_cudaMemset_v3020, + CUPTI_RUNTIME_TRACE_CBID_cudaMemsetAsync_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 #if CUDA_VERSION >= 9000 diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index a5b846f500f3677188b170dda76c65047d628064..a260cda49138580b209e647af459e9392d9f18f1 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -86,8 +86,6 @@ extern void* mklml_dso_handle; __macro(vdPowx); \ __macro(vsInv); \ __macro(vdInv); \ - __macro(vmsErf); \ - __macro(vmdErf); \ __macro(MKL_Set_Num_Threads) MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 269280d604a13a62046fb7811d34b7c69b61b50f..908499e0d8dc679a714a332c8dfe5f16bfbdcd3d 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -548,9 +548,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false), "Fail to find convolution primitive in device context"); if (conv_p == nullptr) { - conv_p = std::make_shared(*conv_pd_, *(src_memory_p), - *(weights_memory_p.get()), - *(dst_memory_p.get())); + conv_p = std::make_shared(*conv_pd_, *src_memory_p, + *weights_memory_p, *dst_memory_p); dev_ctx_.SetBlob(prim_key, conv_p); } else { @@ -570,9 +569,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false), "Fail to find convolution primitive in device context"); if (conv_p == nullptr) { - conv_p = std::make_shared( - *conv_pd_, *(src_memory_p), *(weights_memory_p.get()), - *(bias_memory_p.get()), *(dst_memory_p.get())); + conv_p = std::make_shared(*conv_pd_, *src_memory_p, + 
*weights_memory_p, *bias_memory_p, + *dst_memory_p); dev_ctx_.SetBlob(prim_key, conv_p); } else { diff --git a/paddle/fluid/train/demo/demo_trainer.cc b/paddle/fluid/train/demo/demo_trainer.cc index a0757b53f37b29de0b3802c345b1ad9db69f16e9..1087f5672459506cc7b824127cd822c0df7ba566 100644 --- a/paddle/fluid/train/demo/demo_trainer.cc +++ b/paddle/fluid/train/demo/demo_trainer.cc @@ -73,7 +73,7 @@ int main() { PADDLE_ENFORCE_NE(loss_name, "", "loss not found"); // init all parameters - executor.Run(*startup_program.get(), &scope, 0); + executor.Run(*startup_program, &scope, 0); // prepare data auto x_var = scope.Var("x"); @@ -101,7 +101,7 @@ int main() { clock_t t1 = clock(); for (int i = 0; i < 10; ++i) { - executor.Run(*train_program.get(), &scope, 0, false, true); + executor.Run(*train_program, &scope, 0, false, true); std::cout << "step: " << i << " loss: " << loss_var->Get().data()[0] << std::endl; diff --git a/paddle/fluid/train/test_train_recognize_digits.cc b/paddle/fluid/train/test_train_recognize_digits.cc index e8731dd51ad698e53b7f10cc781c52134f2d17a8..a7846da8c191ac96e9ad7fb5b3184518e32120b2 100644 --- a/paddle/fluid/train/test_train_recognize_digits.cc +++ b/paddle/fluid/train/test_train_recognize_digits.cc @@ -74,7 +74,7 @@ void Train() { float first_loss = 0.0; float last_loss = 0.0; for (int i = 0; i < 100; ++i) { - executor.Run(*train_program.get(), &scope, 0, false, true); + executor.Run(*train_program, &scope, 0, false, true); if (i == 0) { first_loss = loss_var->Get().data()[0]; } else if (i == 99) { diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index fa79db19ee895c028f9cad0f4212aaff0e513784..483a7d4f46f16761b683e4cea90f833be1fe944f 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -19,6 +19,7 @@ import sys from .. import compat as cpt from . import core +from . import framework __all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy'] @@ -110,6 +111,8 @@ class CompiledProgram(object): self._exec_strategy = ExecutionStrategy() if self._build_strategy is None: self._build_strategy = BuildStrategy() + self._build_strategy.is_distribution = framework.is_pserver_mode( + self._program) return self def with_inference_optimize(self, config): diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 15367c724e5304fed78ef58f8a27932e1d6de318..f01b6ab09437a70835e84d66ab59709039b6f54f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -87,6 +87,15 @@ def _current_expected_place(): return _imperative_current_expected_place_ +def is_pserver_mode(main_program): + main = main_program if main_program \ + else default_main_program() + for op in main.global_block().ops: + if op.type in ["send", "recv"]: + return True + return False + + class NameScope(object): def __init__(self, name="", parent=None): self._children = dict() diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py index 59fe6bbf74b80c2260c5b4881fee8807482c9c68..46640ce37a78f7409af7f82d3302a610ccd366b2 100644 --- a/python/paddle/fluid/imperative/layers.py +++ b/python/paddle/fluid/imperative/layers.py @@ -17,7 +17,7 @@ import contextlib import sys import numpy as np import collections - +from .. 
import unique_name from paddle.fluid import core from paddle.fluid import framework from paddle.fluid.imperative import base @@ -26,14 +26,33 @@ __all__ = ['Layer', 'PyLayer'] class Layer(core.Layer): - """Layers composed of operators.""" - - def __init__(self, dtype=core.VarDesc.VarType.FP32, name=None): + """Layers composed of operators. + + Args: + name_scope: prefix name used by the layer to name parameters. + If prefix is "my_model/layer_1", parameter name in MyLayer + can be "my_model/layer_1/MyLayer/w_n", where w is the parameter + base name and n is an unique suffix auto-generated. + dtype: data type for the variables in the layer. + """ + + def __init__(self, name_scope, dtype=core.VarDesc.VarType.FP32): + self._full_name = unique_name.generate(name_scope + "/" + + self.__class__.__name__) self._built = False self._dtype = dtype self._parameters = collections.OrderedDict() self._sub_layers = collections.OrderedDict() + def full_name(self): + """Full name for this layers. + + Full name is composed by name_scope + "/" + MyLayer.__class__.__name__ + + Returns full name of this name. + """ + return self._full_name + def parameters(self, include_sublayers=True): """Returns a list of Parameters from current and sub-layers. diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py index c86a373ae4a92053538c93386003f9014c32841f..41655c4f54eecec55bd2c7d2b74adb51efa88b61 100644 --- a/python/paddle/fluid/imperative/nn.py +++ b/python/paddle/fluid/imperative/nn.py @@ -27,6 +27,7 @@ __all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding'] class Conv2D(layers.Layer): def __init__(self, + name_scope, num_channels, num_filters, filter_size, @@ -38,19 +39,17 @@ class Conv2D(layers.Layer): act=None, param_attr=None, bias_attr=None, - name=None, dtype=core.VarDesc.VarType.FP32): assert param_attr is not False, "param_attr should not be False here." - super(Conv2D, self).__init__(name=name, dtype=dtype) + super(Conv2D, self).__init__(name_scope, dtype=dtype) # TODO(minqiyang): Move this to the top. 
from ..layer_helper import LayerHelper self._helper = LayerHelper( - type(self).__name__, + self.full_name(), param_attr=param_attr, bias_attr=bias_attr, dtype=dtype, - name=name, act=act) self._groups = groups @@ -143,6 +142,7 @@ class Conv2D(layers.Layer): class Pool2D(layers.Layer): def __init__(self, + name_scope, pool_size=-1, pool_type="max", pool_stride=1, @@ -151,7 +151,6 @@ class Pool2D(layers.Layer): use_cudnn=True, ceil_mode=False, exclusive=True, - name=None, dtype=core.VarDesc.VarType.FP32): if pool_type not in ["max", "avg"]: raise ValueError( @@ -166,10 +165,10 @@ class Pool2D(layers.Layer): if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") - super(Pool2D, self).__init__(name=name, dtype=dtype) + super(Pool2D, self).__init__(name_scope, dtype=dtype) from ..layer_helper import LayerHelper - self._helper = LayerHelper(type(self).__name__, dtype=dtype, name=name) + self._helper = LayerHelper(self.full_name(), dtype=dtype) self._pool_type = pool_type self._pool_size = utils.convert_to_list(pool_size, 2, 'pool_size') @@ -205,25 +204,24 @@ class Pool2D(layers.Layer): class FC(layers.Layer): def __init__(self, + name_scope, size, param_attr=None, bias_attr=None, num_flatten_dims=1, dtype=core.VarDesc.VarType.FP32, - act=None, - name=None): - super(FC, self).__init__() + act=None): + super(FC, self).__init__(name_scope) self._size = size self._num_flatten_dims = num_flatten_dims self._dtype = dtype from ..layer_helper import LayerHelper self._helper = LayerHelper( - 'FC', + self.full_name(), param_attr=param_attr, bias_attr=bias_attr, - act=act, - name=name) + act=act) def _build_once(self, input): input_shape = input.shape @@ -282,6 +280,7 @@ class FC(layers.Layer): class BatchNorm(layers.Layer): def __init__(self, + name_scope, num_channels, act=None, is_test=False, @@ -292,22 +291,20 @@ class BatchNorm(layers.Layer): dtype=core.VarDesc.VarType.FP32, data_layout='NCHW', in_place=False, - name=None, moving_mean_name=None, moving_variance_name=None, do_model_average_for_mean_and_var=False, fuse_with_relu=False, use_global_stats=False): - super(BatchNorm, self).__init__() + super(BatchNorm, self).__init__(name_scope) assert bias_attr is not False, "bias_attr should not be False in batch_norm." from ..layer_helper import LayerHelper self._helper = LayerHelper( - 'batch_norm', + self.full_name(), param_attr=param_attr, bias_attr=bias_attr, - name=name, act=act) if dtype == core.VarDesc.VarType.FP16: @@ -419,6 +416,7 @@ class Embedding(layers.Layer): constructor. Args: + name_scope: See base class. size(tuple|list): The shape of the look up table parameter. It should have two elements which indicate the size of the dictionary of embeddings and the size of each embedding vector respectively. 
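The name_scope plumbing above changes how imperative layers are built: the scope string is now the first positional argument of every layer, and full_name() (name_scope + "/" + class name plus a unique suffix) is what LayerHelper uses to prefix parameter names. A hedged sketch of the resulting user-side pattern follows; the class, scope string, and import paths are assumptions for illustration, not part of the diff.

```python
import paddle.fluid as fluid
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC

class SimpleConvNet(fluid.imperative.Layer):
    def __init__(self, name_scope):
        # name_scope is now required; parameters created by the sub-layers are
        # prefixed with their full_name(), e.g. something like
        # "mnist/SimpleConvNet_0/Conv2D_0.w_0".
        super(SimpleConvNet, self).__init__(name_scope)
        self._conv = Conv2D(self.full_name(), num_channels=1,
                            num_filters=20, filter_size=5, act='relu')
        self._pool = Pool2D(self.full_name(), pool_size=2,
                            pool_stride=2, pool_type='max')
        self._fc = FC(self.full_name(), size=10, act='softmax')

    def forward(self, inputs):
        x = self._pool(self._conv(inputs))
        return self._fc(x)
```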
@@ -446,6 +444,7 @@ class Embedding(layers.Layer): """ def __init__(self, + name_scope, size, is_sparse=False, is_distributed=False, @@ -453,7 +452,7 @@ class Embedding(layers.Layer): param_attr=None, dtype='float32'): - super(Embedding, self).__init__() + super(Embedding, self).__init__(name_scope) self._size = size self._is_sparse = is_sparse self._is_distributed = is_distributed @@ -468,7 +467,7 @@ class Embedding(layers.Layer): assert self._is_sparse is True and self._is_distributed is False from ..layer_helper import LayerHelper - self._helper = LayerHelper('embedding', param_attr=param_attr) + self._helper = LayerHelper(self.full_name(), param_attr=param_attr) self._w = self._helper.create_parameter( attr=self._param_attr, shape=self._size, diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index 7d1636774c6e27ec8090ac01710e23beed5fd0e8..65864ca7e09cd4f0760637198d48154eed025c65 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -34,6 +34,9 @@ class LayerHelper(object): self.kwargs = kwargs self.layer_type = layer_type name = self.kwargs.get('name', None) + # TODO(panyx0718, minqiyang): imperative mode + # can not use both `layer_type` and `name`. Deprecate LayerHelper + # and write a Helper for imperative mode. if name is None: self.kwargs['name'] = unique_name.generate(self.layer_type) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 3b43ae0b9cb63a9f4708a680cb1021d74c197550..61a7d4f31d5245e635e2e1fe33e418ce20e94180 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -545,15 +545,16 @@ def yolov3_loss(x, TypeError: Attr ignore_thresh of yolov3_loss must be a float number Examples: - .. code-block:: python - - x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') - gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') - gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') - anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] - anchors = [0, 1, 2] - loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80, anchors=anchors, - ignore_thresh=0.5, downsample_ratio=32) + .. 
code-block:: python + + x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') + gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') + gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') + anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + anchor_mask = [0, 1, 2] + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, anchors=anchors, + anchor_mask=anchor_mask, class_num=80, + ignore_thresh=0.7, downsample_ratio=32) """ helper = LayerHelper('yolov3_loss', **locals()) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1a7d076835841e3c1b8a90a43437eff13645fb8a..2315a2d5ccd7598f0b897722cc35192d477e0d5e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -87,6 +87,7 @@ __all__ = [ 'transpose', 'im2sequence', 'nce', + 'sampled_softmax_with_cross_entropy', 'hsigmoid', 'beam_search', 'row_conv', @@ -668,7 +669,11 @@ def dynamic_lstmp(input, candidate_activation='tanh', proj_activation='tanh', dtype='float32', - name=None): + name=None, + h_0=None, + c_0=None, + cell_clip=None, + proj_clip=None): """ **Dynamic LSTMP Layer** @@ -785,6 +790,17 @@ def dynamic_lstmp(input, dtype(str): Data type. Choices = ["float32", "float64"], default "float32". name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. + h_0(Variable): The initial hidden state is an optional input, default is zero. + This is a tensor with shape (N x D), where N is the + batch size and D is the projection size. + c_0(Variable): The initial cell state is an optional input, default is zero. + This is a tensor with shape (N x D), where N is the + batch size. `h_0` and `c_0` can be NULL but only at the same time. + cell_clip(float): If provided the cell state is clipped + by this value prior to the cell output activation. + proj_clip(float): If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. Returns: tuple: A tuple of two output variable: the projection of hidden state, \ @@ -831,25 +847,41 @@ def dynamic_lstmp(input, batch_hidden = helper.create_variable_for_type_inference(dtype) batch_gate = helper.create_variable_for_type_inference(dtype) batch_cell_pre_act = helper.create_variable_for_type_inference(dtype) + inputs = { + 'Input': input, + 'Weight': weight, + 'ProjWeight': proj_weight, + 'Bias': bias + } + batch_size = input.shape[0] + if h_0: + assert h_0.shape == (batch_size, proj_size), \ + 'The shape of h0 should be (batch_size, %d)' % proj_size + inputs['H0'] = h_0 + if c_0: + assert c_0.shape == (batch_size, size), \ + 'The shape of c0 should be (batch_size, %d)' % size + inputs['C0'] = c_0 + + if cell_clip: + assert cell_clip >= 0, "cell_clip should not be negtive." + if proj_clip: + assert proj_clip >= 0, "proj_clip should not be negtive." 
helper.append_op( type='lstmp', - inputs={ - 'Input': input, - 'Weight': weight, - 'ProjWeight': proj_weight, - 'Bias': bias - }, + inputs=inputs, outputs={ 'Projection': projection, 'Cell': cell, - 'OrderedP0': ordered_proj0, 'BatchHidden': batch_hidden, 'BatchGate': batch_gate, 'BatchCellPreAct': batch_cell_pre_act }, attrs={ 'use_peepholes': use_peepholes, + 'cell_clip': cell_clip, + 'proj_clip': proj_clip, 'is_reverse': is_reverse, 'gate_activation': gate_activation, 'cell_activation': cell_activation, @@ -2569,7 +2601,27 @@ def adaptive_pool2d(input, require_index=False, name=None): """ - ${comment} + **Adaptive Pool2d Operator** + The adaptive_pool2d operation calculates the output based on the input, pool_size, + pool_type parameters. Input(X) and output(Out) are in NCHW format, where N is batch + size, C is the number of channels, H is the height of the feature, and W is + the width of the feature. Parameters(pool_size) should contain two elements which + represent height and width, respectively. Also the H and W dimensions of output(Out) + is same as Parameter(pool_size). + + For average adaptive pool2d: + + .. math:: + + hstart &= floor(i * H_{in} / H_{out}) + + hend &= ceil((i + 1) * H_{in} / H_{out}) + + wstart &= floor(j * W_{in} / W_{out}) + + wend &= ceil((j + 1) * W_{in} / W_{out}) + + Output(i ,j) &= \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)} Args: input (Variable): The input tensor of pooling operator. The format of @@ -2579,8 +2631,8 @@ def adaptive_pool2d(input, pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, it must contain two integers, (pool_size_Height, pool_size_Width). pool_type: ${pooling_type_comment} - require_index (bool): If true, the index of max pooling point along with outputs. - it cannot be set in average pooling type. + require_index (bool): If true, the index of max pooling point will be returned along + with outputs. It cannot be set in average pooling type. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2661,18 +2713,42 @@ def adaptive_pool3d(input, require_index=False, name=None): """ - ${comment} + **Adaptive Pool3d Operator** + The adaptive_pool3d operation calculates the output based on the input, pool_size, + pool_type parameters. Input(X) and output(Out) are in NCDHW format, where N is batch + size, C is the number of channels, D is the depth of the feature, H is the height of + the feature, and W is the width of the feature. Parameters(pool_size) should contain + three elements which represent height and width, respectively. Also the D, H and W + dimensions of output(Out) is same as Parameter(pool_size). + + For average adaptive pool3d: + + .. math:: + + dstart &= floor(i * D_{in} / D_{out}) + + dend &= ceil((i + 1) * D_{in} / D_{out}) + + hstart &= floor(j * H_{in} / H_{out}) + + hend &= ceil((j + 1) * H_{in} / H_{out}) + + wstart &= floor(k * W_{in} / W_{out}) + + wend &= ceil((k + 1) * W_{in} / W_{out}) + + Output(i ,j, k) &= \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)} Args: input (Variable): The input tensor of pooling operator. The format of - input tensor is NCHW, where N is batch size, C is - the number of channels, H is the height of the - feature, and W is the width of the feature. 
+ input tensor is NCDHW, where N is batch size, C is + the number of channels, D is the depth of the feature, + H is the height of the feature, and W is the width of the feature. pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, - it must contain two integers, (Depth, Height, Width). + it must contain three integers, (Depth, Height, Width). pool_type: ${pooling_type_comment} - require_index (bool): If true, the index of max pooling point along with outputs. - it cannot be set in average pooling type. + require_index (bool): If true, the index of max pooling point will be returned along + with outputs. It cannot be set in average pooling type. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -2709,7 +2785,7 @@ def adaptive_pool3d(input, name='data', shape=[3, 32, 32], dtype='float32') pool_out, mask = fluid.layers.adaptive_pool3d( input=data, - pool_size=[3, 3], + pool_size=[3, 3, 3], pool_type='avg') """ if pool_type not in ["max", "avg"]: @@ -5765,6 +5841,132 @@ def softmax_with_cross_entropy(logits, return loss +def sampled_softmax_with_cross_entropy(logits, + label, + num_samples, + num_true=1, + remove_accidental_hits=True, + use_customized_samples=False, + customized_samples=None, + customized_probabilities=None, + seed=0): + """ + **Sampled Softmax With Cross Entropy Operator.** + + Cross entropy loss with sampled softmax is widely used as the output layer + when the number of output classes is large. This operator samples a number of classes + for all examples, and computes the softmax normalized values for each + row of the sampled tensor, after which the cross-entropy loss is computed. + + Because this operator performs a softmax on logits internally, it expects + unscaled logits. This operator should not be used with the output of + softmax operator since that would produce incorrect results. + + For examples with T true labels (T >= 1), we assume that each true label has + a probability of 1/T. For each example, S samples are generated using a + log uniform distribution. True labels are concatenated with these samples to + form T + S samples for each example. So, assuming the shape of logits is + [N x K], the shape of samples is [N x (T+S)]. For each sampled label, a + probability is calculated, which corresponds to the Q(y|x) in + [Jean et al., 2014](http://arxiv.org/abs/1412.2007). + + Logits are gathered according to the sampled labels. If + remove_accidental_hits is True and a sample[i, j] accidentally hits a true + label, the corresponding sampled_logits[i, j] has 1e20 subtracted from it to + push its softmax result close to zero. logQ(y|x) is then subtracted from the + sampled logits, and these sampled logits together with the re-indexed labels + are used to compute a softmax with cross entropy. + + Args: + logits (Variable): The unscaled log probabilities, which is a 2-D tensor + with shape [N x K]. N is the batch_size, and K is the class number. + label (Variable): The ground truth which is a 2-D tensor. Label is a + Tensor with shape [N x T], where T is the number of true + labels per example. + num_samples (int): The number of samples drawn for each example; num_samples should be + less than the number of classes. + num_true(int): The number of target classes per training example. + remove_accidental_hits (bool): A flag indicating whether to remove + accidental hits when sampling. 
If True and if a sample[i, j] + accidentally hits true labels, then the corresponding + sampled_logits[i, j] has 1e20 subtracted from it to push its softmax result + close to zero. Default is True. + use_customized_samples (bool): Whether to use customized samples and probabilities to sample + logits. + customized_samples (Variable): User defined samples, which is a 2-D tensor + with shape [N, T + S]. S is the num_samples, and T is the number of true + labels per example. + customized_probabilities (Variable): User defined probabilities of samples, + a 2-D tensor which has the same shape as customized_samples. + seed (int): The random seed for generating random numbers, which is used + in the process of sampling. Default is 0. + + Returns: + Variable: The cross entropy loss, which is a 2-D tensor with shape + [N x 1]. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[256], dtype='float32') + label = fluid.layers.data(name='label', shape=[5], dtype='int64') + fc = fluid.layers.fc(input=data, size=100) + out = fluid.layers.sampled_softmax_with_cross_entropy( + logits=fc, label=label, num_samples=25) + """ + helper = LayerHelper('sample_logits', **locals()) + samples = helper.create_variable_for_type_inference(dtype='int64') + probabilities = helper.create_variable_for_type_inference( + dtype=logits.dtype) + sampled_logits \ + = helper.create_variable_for_type_inference(dtype=logits.dtype) + sampled_label = helper.create_variable_for_type_inference(dtype='int64') + sampled_softlabel = helper.create_variable_for_type_inference( + dtype=logits.dtype) + + helper.append_op( + type='sample_logits', + inputs={ + 'Logits': logits, + 'Labels': label, + 'CustomizedSamples': customized_samples, + 'CustomizedProbabilities': customized_probabilities + }, + outputs={ + 'Samples': samples, + 'Probabilities': probabilities, + 'SampledLabels': sampled_label, + 'SampledLogits': sampled_logits + }, + attrs={ + 'use_customized_samples': use_customized_samples, + 'uniq': True, + 'remove_accidental_hits': remove_accidental_hits, + 'num_samples': num_samples, + 'seed': seed + }) + loss = helper.create_variable_for_type_inference(dtype=logits.dtype) + softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) + helper.append_op( + type='one_hot', + inputs={'X': sampled_label}, + attrs={'depth': num_samples + 1}, + outputs={'Out': sampled_softlabel}) + + helper.append_op( + type='softmax_with_cross_entropy', + inputs={'Logits': sampled_logits, + 'Label': sampled_softlabel}, + outputs={'Softmax': softmax, + 'Loss': loss}, + attrs={ + 'soft_label': True, + 'ignore_index': False, + 'numeric_stable_mode': False + }) + return loss / num_true + + def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): """ This layer computes the smooth L1 loss for Variable :attr:`x` and :attr:`y`. 
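To make the sampling procedure described in the docstring above more concrete, here is a rough NumPy sketch of the forward loss computation. It is illustrative only: it assumes one true label per example (num_true == 1), draws negatives from a plain uniform distribution instead of the operator's log-uniform sampler, and omits the logQ(y|x) correction; `sampled_softmax_loss_sketch` is a hypothetical helper, not part of the Paddle API.

```python
import numpy as np


def sampled_softmax_loss_sketch(logits, label, num_samples, seed=0):
    # logits: [N, K] unscaled scores, label: [N] integer class ids
    rng = np.random.RandomState(seed)
    n, k = logits.shape
    losses = np.zeros(n)
    for i in range(n):
        true = int(label[i])
        # draw negative classes (a plain uniform sampler, unlike the operator)
        sampled = rng.choice(k, size=num_samples, replace=False)
        # index 0 of the reduced label space is the true class
        picked = logits[i, np.concatenate(([true], sampled))].astype(np.float64)
        # "remove accidental hits": a negative equal to the true class is
        # pushed towards -inf so its softmax probability becomes ~0
        picked[1:][sampled == true] -= 1e20
        # softmax with cross entropy over the reduced label space
        picked -= picked.max()
        probs = np.exp(picked) / np.exp(picked).sum()
        losses[i] = -np.log(probs[0] + 1e-12)
    return losses


# toy check: 4 examples, 1000 classes, 25 sampled negatives per example
print(sampled_softmax_loss_sketch(
    np.random.randn(4, 1000).astype('float32'),
    np.array([3, 7, 42, 999]), num_samples=25))
```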
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 8586670c2481a0f997e84f33860b6df28b3223ae..648bf69273f7ce31431bf8006c4540580cb94b61 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -29,15 +29,6 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy BuildStrategy = core.ParallelExecutor.BuildStrategy -def _is_pserver_mode(main_program): - main = main_program if main_program \ - else framework.default_main_program() - for op in main.global_block().ops: - if op.type in ["send", "recv"]: - return True - return False - - class ParallelExecutor(object): """ ParallelExecutor is designed for data parallelism, which focuses on distributing @@ -140,7 +131,7 @@ class ParallelExecutor(object): # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, # num_trainers is 1, so the current fields of build_strategy doesn't tell if # it's distributed model. - build_strategy.is_distribution = _is_pserver_mode( + build_strategy.is_distribution = framework.is_pserver_mode( main_program) or num_trainers > 1 # step4: get main_program, scope, local_scopes diff --git a/python/paddle/fluid/tests/unittests/test_base_layer.py b/python/paddle/fluid/tests/unittests/test_base_layer.py index bf00698d63624d4e20a0853641219a2735d89d25..caf9750e58889ac40c7cdde022f0b6aa5e77fc42 100644 --- a/python/paddle/fluid/tests/unittests/test_base_layer.py +++ b/python/paddle/fluid/tests/unittests/test_base_layer.py @@ -20,10 +20,10 @@ from paddle.fluid.layer_helper import LayerHelper class L1(fluid.imperative.Layer): - def __init__(self): - super(L1, self).__init__() + def __init__(self, prefix): + super(L1, self).__init__(prefix) self._helper = LayerHelper( - 'MyLayer', + self.full_name(), param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(value=0.1))) @@ -43,20 +43,20 @@ class L1(fluid.imperative.Layer): class L2(fluid.imperative.Layer): - def __init__(self): - super(L2, self).__init__() - self.layer1 = L1() - self.layer2 = L1() + def __init__(self, prefix): + super(L2, self).__init__(prefix) + self.layer1 = L1(self.full_name()) + self.layer2 = L1(self.full_name()) def forward(self): return self.layer1() + self.layer2() class L3(fluid.imperative.Layer): - def __init__(self): - super(L3, self).__init__() - self.layer1 = L2() - self.layer2 = L2() + def __init__(self, prefix): + super(L3, self).__init__(prefix) + self.layer1 = L2(self.full_name()) + self.layer2 = L2(self.full_name()) def forward(self): return self.layer1() + self.layer2() @@ -65,16 +65,23 @@ class L3(fluid.imperative.Layer): class TestBaseLayer(unittest.TestCase): def test_one_level(self): with fluid.imperative.guard(): - l = L1() + l = L1('test_one_level') ret = l() - self.assertEqual(l.w1.name, "MyLayer_0.w_0") - self.assertEqual(l.w2.name, "MyLayer_0.w_1") + self.assertEqual(l.w1.name, "test_one_level/L1_0_0.w_0") + self.assertEqual(l.w2.name, "test_one_level/L1_0_0.w_1") self.assertTrue(np.allclose(ret._numpy(), 0.2 * np.ones([2, 2]))) def test_three_level(self): with fluid.imperative.guard(): - l = L3() + l = L3('test_three_level') + names = [p.name for p in l.parameters()] ret = l() + self.assertEqual(names[0], "test_three_level/L3_0/L2_0/L1_0_0.w_0") + self.assertEqual(names[1], "test_three_level/L3_0/L2_0/L1_0_0.w_1") + self.assertEqual(names[2], "test_three_level/L3_0/L2_0/L1_1_0.w_0") + self.assertEqual(names[3], "test_three_level/L3_0/L2_0/L1_1_0.w_1") + self.assertEqual(names[4], 
"test_three_level/L3_0/L2_1/L1_0_0.w_0") + self.assertEqual(names[5], "test_three_level/L3_0/L2_1/L1_0_0.w_1") self.assertTrue(np.allclose(ret._numpy(), 0.8 * np.ones([2, 2]))) diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py index c54e998ea875e1bd27f9816f88db0e38bc488459..dae0c466ee5ea919688b29100f77f17f5f3b8c6d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative.py +++ b/python/paddle/fluid/tests/unittests/test_imperative.py @@ -15,7 +15,6 @@ import contextlib import unittest import numpy as np -import sys import paddle.fluid as fluid from paddle.fluid import core @@ -24,8 +23,8 @@ from test_imperative_base import new_program_scope class MyLayer(fluid.imperative.Layer): - def __init__(self): - super(MyLayer, self).__init__() + def __init__(self, name_scope): + super(MyLayer, self).__init__(name_scope) def forward(self, inputs): x = fluid.layers.relu(inputs) @@ -50,12 +49,14 @@ class MyPyLayer(fluid.imperative.PyLayer): class MLP(fluid.imperative.Layer): - def __init__(self): - super(MLP, self).__init__() - self._fc1 = FC(3, + def __init__(self, name_scope): + super(MLP, self).__init__(name_scope) + self._fc1 = FC(self.full_name(), + 3, fluid.ParamAttr( initializer=fluid.initializer.Constant(value=0.1))) - self._fc2 = FC(4, + self._fc2 = FC(self.full_name(), + 4, fluid.ParamAttr( initializer=fluid.initializer.Constant(value=0.1))) @@ -67,8 +68,9 @@ class MLP(fluid.imperative.Layer): class SimpleRNNCell(fluid.imperative.Layer): - def __init__(self, step_input_size, hidden_size, output_size, param_attr): - super(SimpleRNNCell, self).__init__() + def __init__(self, name_scope, step_input_size, hidden_size, output_size, + param_attr): + super(SimpleRNNCell, self).__init__(name_scope) self.step_input_size = step_input_size self.hidden_size = hidden_size self.output_size = output_size @@ -158,10 +160,11 @@ class SimpleRNNCell(fluid.imperative.Layer): class SimpleRNN(fluid.imperative.Layer): - def __init__(self): - super(SimpleRNN, self).__init__() + def __init__(self, name_scope): + super(SimpleRNN, self).__init__(name_scope) self.seq_len = 4 self._cell = SimpleRNNCell( + self.full_name(), 3, 3, 3, @@ -205,7 +208,7 @@ class TestImperative(unittest.TestCase): with fluid.imperative.guard(): cl = core.Layer() cl.forward([]) - l = fluid.imperative.Layer() + l = fluid.imperative.Layer("l") self.assertRaises(NotImplementedError, l.forward, []) def test_pylayer_func_id(self): @@ -281,7 +284,7 @@ class TestImperative(unittest.TestCase): np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32) with fluid.imperative.guard(): var_inp = fluid.imperative.base.to_variable(np_inp) - l = MyLayer() + l = MyLayer("my_layer") x = l(var_inp)[0] self.assertIsNotNone(x) dy_out = x._numpy() @@ -291,7 +294,7 @@ class TestImperative(unittest.TestCase): with new_program_scope(): inp = fluid.layers.data( name="inp", shape=[3], append_batch_size=False) - l = MyLayer() + l = MyLayer("my_layer") x = l(inp)[0] param_grads = fluid.backward.append_backward( x, parameter_list=[l._x_for_debug.name])[0] @@ -309,7 +312,7 @@ class TestImperative(unittest.TestCase): np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) with fluid.imperative.guard(): var_inp = fluid.imperative.base.to_variable(np_inp) - mlp = MLP() + mlp = MLP("mlp") out = mlp(var_inp) dy_out = out._numpy() out._backward() @@ -318,7 +321,7 @@ class TestImperative(unittest.TestCase): with new_program_scope(): inp = fluid.layers.data( name="inp", shape=[2, 2], 
append_batch_size=False) - mlp = MLP() + mlp = MLP("mlp") out = mlp(inp) param_grads = fluid.backward.append_backward( out, parameter_list=[mlp._fc1._w.name])[0] @@ -334,10 +337,10 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.allclose(dy_grad, static_grad)) params = mlp.parameters(True) - self.assertEqual("FC_0.w_0", params[0].name) - self.assertEqual("FC_0.b_0", params[1].name) - self.assertEqual("FC_1.w_0", params[2].name) - self.assertEqual("FC_1.b_0", params[3].name) + self.assertEqual("mlp/MLP_0/FC_0_0.w_0", params[0].name) + self.assertEqual("mlp/MLP_0/FC_0_0.b_0", params[1].name) + self.assertEqual("mlp/MLP_0/FC_1_0.w_0", params[2].name) + self.assertEqual("mlp/MLP_0/FC_1_0.b_0", params[3].name) self.assertEqual(len(params), 4) sublayers = mlp.sublayers(True) @@ -353,7 +356,7 @@ class TestImperative(unittest.TestCase): with fluid.imperative.guard(): var_inp = fluid.imperative.base.to_variable(np_inp) var_inp = fluid.layers.reshape(var_inp, shape=[1, 4, 3]) - simple_rnn = SimpleRNN() + simple_rnn = SimpleRNN("simple_rnn") outs, pre_hiddens = simple_rnn.forward(var_inp) dy_out = outs[3]._numpy() outs[3]._backward() @@ -364,7 +367,7 @@ class TestImperative(unittest.TestCase): with new_program_scope(): inp = fluid.layers.data( name="inp", shape=[1, 4, 3], append_batch_size=False) - simple_rnn = SimpleRNN() + simple_rnn = SimpleRNN("simple_rnn") outs, pre_hiddens = simple_rnn(inp) param_grads = fluid.backward.append_backward(outs[3]) exe = fluid.Executor(fluid.CPUPlace()) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gan.py b/python/paddle/fluid/tests/unittests/test_imperative_gan.py index 33c196d1ab52b393491561e75054e6c323fce18d..a80202d6dddacaa4cb6fa3efd3c3dfd5b0ab4400 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gan.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gan.py @@ -28,10 +28,10 @@ from paddle.fluid.imperative.base import to_variable class Discriminator(fluid.imperative.Layer): - def __init__(self): - super(Discriminator, self).__init__() - self._fc1 = FC(size=32, act='elu', name="d_fc1") - self._fc2 = FC(size=1, name="d_fc2") + def __init__(self, name_scope): + super(Discriminator, self).__init__(name_scope) + self._fc1 = FC(self.full_name(), size=32, act='elu') + self._fc2 = FC(self.full_name(), size=1) def forward(self, inputs): x = self._fc1(inputs) @@ -39,11 +39,11 @@ class Discriminator(fluid.imperative.Layer): class Generator(fluid.imperative.Layer): - def __init__(self): - super(Generator, self).__init__() - self._fc1 = FC(size=64, act='elu', name="g_fc1") - self._fc2 = FC(size=64, act='elu', name="g_fc2") - self._fc3 = FC(size=1, name="g_fc3") + def __init__(self, name_scope): + super(Generator, self).__init__(name_scope) + self._fc1 = FC(self.full_name(), size=64, act='elu') + self._fc2 = FC(self.full_name(), size=64, act='elu') + self._fc3 = FC(self.full_name(), size=1) def forward(self, inputs): x = self._fc1(inputs) @@ -65,8 +65,8 @@ class TestImperativeMnist(unittest.TestCase): scope = fluid.core.Scope() with new_program_scope( main=discriminate_p, startup=startup, scope=scope): - discriminator = Discriminator() - generator = Generator() + discriminator = Discriminator("d") + generator = Generator("g") img = fluid.layers.data( name="img", shape=[2, 1], append_batch_size=False) @@ -93,8 +93,8 @@ class TestImperativeMnist(unittest.TestCase): sgd.minimize(d_loss) with new_program_scope(main=generate_p, startup=startup, scope=scope): - discriminator = Discriminator() - generator = Generator() + 
discriminator = Discriminator("d") + generator = Generator("g") noise = fluid.layers.data( name="noise", shape=[2, 2], append_batch_size=False) @@ -134,8 +134,8 @@ class TestImperativeMnist(unittest.TestCase): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - discriminator = Discriminator() - generator = Generator() + discriminator = Discriminator("d") + generator = Generator("g") sgd = SGDOptimizer(learning_rate=1e-3) d_real = discriminator(to_variable(np.ones([2, 1], np.float32))) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index 08b155acc657c3a4a73f5b1d72ac356fc7e83a58..780c6a6be567c9f60f472c27cebd5300d56eb378 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -28,6 +28,7 @@ from test_imperative_base import new_program_scope class SimpleImgConvPool(fluid.imperative.Layer): def __init__(self, + name_scope, num_channels, num_filters, filter_size, @@ -44,9 +45,10 @@ class SimpleImgConvPool(fluid.imperative.Layer): use_cudnn=False, param_attr=None, bias_attr=None): - super(SimpleImgConvPool, self).__init__() + super(SimpleImgConvPool, self).__init__(name_scope) self._conv2d = Conv2D( + self.full_name(), num_channels=num_channels, num_filters=num_filters, filter_size=filter_size, @@ -59,6 +61,7 @@ class SimpleImgConvPool(fluid.imperative.Layer): use_cudnn=use_cudnn) self._pool2d = Pool2D( + self.full_name(), pool_size=pool_size, pool_type=pool_type, pool_stride=pool_stride, @@ -73,19 +76,20 @@ class SimpleImgConvPool(fluid.imperative.Layer): class MNIST(fluid.imperative.Layer): - def __init__(self, param_attr=None, bias_attr=None): - super(MNIST, self).__init__() + def __init__(self, name_scope, param_attr=None, bias_attr=None): + super(MNIST, self).__init__(name_scope) self._simple_img_conv_pool_1 = SimpleImgConvPool( - 1, 20, 5, 2, 2, act="relu") + self.full_name(), 1, 20, 5, 2, 2, act="relu") self._simple_img_conv_pool_2 = SimpleImgConvPool( - 20, 50, 5, 2, 2, act="relu") + self.full_name(), 20, 50, 5, 2, 2, act="relu") pool_2_shape = 50 * 4 * 4 SIZE = 10 scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 - self._fc = FC(10, + self._fc = FC(self.full_name(), + 10, param_attr=fluid.param_attr.ParamAttr( initializer=fluid.initializer.NormalInitializer( loc=0.0, scale=scale)), @@ -106,7 +110,7 @@ class TestImperativeMnist(unittest.TestCase): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - mnist = MNIST() + mnist = MNIST("mnist") sgd = SGDOptimizer(learning_rate=1e-3) train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=128) @@ -150,7 +154,7 @@ class TestImperativeMnist(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace( ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) - mnist = MNIST() + mnist = MNIST("mnist") sgd = SGDOptimizer(learning_rate=1e-3) train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=128) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py index 7cf3bf13d2072bac3bf7bd74760c56e7ac12b8a7..c8e42d5ede57896b0d5c09a2334709ced2d16a3f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py @@ -28,12 +28,13 @@ from paddle.fluid.backward import append_backward 
class SimpleLSTMRNN(fluid.imperative.Layer): def __init__(self, + name_scope, hidden_size, num_steps, num_layers=2, init_scale=0.1, dropout=None): - super(SimpleLSTMRNN, self).__init__() + super(SimpleLSTMRNN, self).__init__(name_scope) self._hidden_size = hidden_size self._num_layers = num_layers self._init_scale = init_scale @@ -130,13 +131,14 @@ class SimpleLSTMRNN(fluid.imperative.Layer): class PtbModel(fluid.imperative.Layer): def __init__(self, + name_scope, hidden_size, vocab_size, num_layers=2, num_steps=20, init_scale=0.1, dropout=None): - super(PtbModel, self).__init__() + super(PtbModel, self).__init__(name_scope) self.hidden_size = hidden_size self.vocab_size = vocab_size self.init_scale = init_scale @@ -146,12 +148,14 @@ class PtbModel(fluid.imperative.Layer): from paddle.fluid.layer_helper import LayerHelper self._helper = LayerHelper('PtbModel', act="tanh") self.simple_lstm_rnn = SimpleLSTMRNN( + self.full_name(), hidden_size, num_steps, num_layers=num_layers, init_scale=init_scale, dropout=dropout) self.embedding = Embedding( + self.full_name(), size=[vocab_size, hidden_size], dtype='float32', is_sparse=False, @@ -226,6 +230,7 @@ class TestImperativePtbRnn(unittest.TestCase): fluid.default_main_program().random_seed = seed # TODO: marsyang1993 Change seed to ptb_model = PtbModel( + "ptb_model", hidden_size=hidden_size, vocab_size=vocab_size, num_layers=num_layers, @@ -265,6 +270,7 @@ class TestImperativePtbRnn(unittest.TestCase): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed ptb_model = PtbModel( + "ptb_model", hidden_size=hidden_size, vocab_size=vocab_size, num_layers=num_layers, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index 128d18621db8374c6c385dddbefc0d29e760a02f..0e134742a7e80c462206072644bb4bf196397b38 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -70,15 +70,17 @@ def optimizer_setting(params): class ConvBNLayer(fluid.imperative.Layer): def __init__(self, + name_scope, num_channels, num_filters, filter_size, stride=1, groups=1, act=None): - super(ConvBNLayer, self).__init__() + super(ConvBNLayer, self).__init__(name_scope) self._conv = Conv2D( + self.full_name(), num_channels=num_channels, num_filters=num_filters, filter_size=filter_size, @@ -88,7 +90,7 @@ class ConvBNLayer(fluid.imperative.Layer): act=None, bias_attr=None) - self._batch_norm = BatchNorm(num_filters, act=act) + self._batch_norm = BatchNorm(self.full_name(), num_filters, act=act) def forward(self, inputs): y = self._conv(inputs) @@ -98,21 +100,29 @@ class ConvBNLayer(fluid.imperative.Layer): class BottleneckBlock(fluid.imperative.Layer): - def __init__(self, num_channels, num_filters, stride, shortcut=True): - super(BottleneckBlock, self).__init__() + def __init__(self, + name_scope, + num_channels, + num_filters, + stride, + shortcut=True): + super(BottleneckBlock, self).__init__(name_scope) self.conv0 = ConvBNLayer( + self.full_name(), num_channels=num_channels, num_filters=num_filters, filter_size=1, act='relu') self.conv1 = ConvBNLayer( + self.full_name(), num_channels=num_filters, num_filters=num_filters, filter_size=3, stride=stride, act='relu') self.conv2 = ConvBNLayer( + self.full_name(), num_channels=num_filters, num_filters=num_filters * 4, filter_size=1, @@ -120,6 +130,7 @@ class BottleneckBlock(fluid.imperative.Layer): if not shortcut: 
self.short = ConvBNLayer( + self.full_name(), num_channels=num_channels, num_filters=num_filters * 4, filter_size=1, @@ -141,13 +152,13 @@ class BottleneckBlock(fluid.imperative.Layer): y = fluid.layers.elementwise_add(x=short, y=conv2) - layer_helper = LayerHelper('elementwise_add_activation', act='relu') + layer_helper = LayerHelper(self.full_name(), act='relu') return layer_helper.append_activation(y) class ResNet(fluid.imperative.Layer): - def __init__(self, layers=50, class_dim=102): - super(ResNet, self).__init__() + def __init__(self, name_scope, layers=50, class_dim=102): + super(ResNet, self).__init__(name_scope) self.layers = layers supported_layers = [50, 101, 152] @@ -163,9 +174,18 @@ class ResNet(fluid.imperative.Layer): num_filters = [64, 128, 256, 512] self.conv = ConvBNLayer( - num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu') + self.full_name(), + num_channels=3, + num_filters=64, + filter_size=7, + stride=2, + act='relu') self.pool2d_max = Pool2D( - pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + self.full_name(), + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') self.bottleneck_block_list = [] num_channels = 64 @@ -175,6 +195,7 @@ class ResNet(fluid.imperative.Layer): bottleneck_block = self.add_sublayer( 'bb_%d_%d' % (block, i), BottleneckBlock( + self.full_name(), num_channels=num_channels, num_filters=num_filters[block], stride=2 if i == 0 and block != 0 else 1, @@ -184,12 +205,13 @@ class ResNet(fluid.imperative.Layer): shortcut = True self.pool2d_avg = Pool2D( - pool_size=7, pool_type='avg', global_pooling=True) + self.full_name(), pool_size=7, pool_type='avg', global_pooling=True) import math stdv = 1.0 / math.sqrt(2048 * 1.0) - self.out = FC(size=class_dim, + self.out = FC(self.full_name(), + size=class_dim, act='softmax', param_attr=fluid.param_attr.ParamAttr( initializer=fluid.initializer.Uniform(-stdv, stdv))) @@ -214,7 +236,7 @@ class TestImperativeResnet(unittest.TestCase): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - resnet = ResNet() + resnet = ResNet("resnet") optimizer = optimizer_setting(train_parameters) np.random.seed(seed) import random @@ -275,7 +297,7 @@ class TestImperativeResnet(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace( ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) - resnet = ResNet() + resnet = ResNet("resnet") optimizer = optimizer_setting(train_parameters) np.random.seed(seed) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e7bc1601a54c8615e0e787d74145aa4987b6cb88..30194f8cacfea2361ffe4afe537287a261cf470b 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -374,6 +374,17 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_sampled_softmax_with_cross_entropy(self): + program = Program() + with program_guard(program): + logits = layers.data(name='Logits', shape=[256], dtype='float64') + label = layers.data(name='Label', shape=[1], dtype='int64') + num_samples = 25 + output = layers.sampled_softmax_with_cross_entropy(logits, label, + num_samples) + self.assertIsNotNone(output) + print(str(program)) + @decorators.prog_scope() def test_nce(self): window_size = 5 diff --git a/python/paddle/fluid/tests/unittests/test_lstmp_op.py b/python/paddle/fluid/tests/unittests/test_lstmp_op.py index 
9c3ec45515ffe0a07541fd9cfb7e92b079264071..0645cfedb8089f5618c54672cac91343e5dee285 100644 --- a/python/paddle/fluid/tests/unittests/test_lstmp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstmp_op.py @@ -36,12 +36,14 @@ def lstmp( w_b=None, # 1 x 4D w_c=None, # 1 x 3D is_reverse=False, + proj_clip=0.0, + cell_clip=0.0, act_gate=None, act_cell=None, act_cand=None, act_proj=None): - def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand, - act_proj): + def _step(x, w_r, w_rh, w_c, r_pre, c_pre, proj_clip, cell_clip, act_gate, + act_cell, act_cand, act_proj): g = np.dot(r_pre, w_r) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) @@ -55,6 +57,17 @@ def lstmp( g_f = act_gate(g_f + w_fc * c_pre) # 1 x D c = g_f * c_pre + g_i * act_cand(c) # 1 x D + def array_clip(a, clip): + size = np.prod(a.shape) + new_a = np.reshape(a, (size)) + for i in range(size): + new_a[i] = max(new_a[i], -1.0 * clip) + new_a[i] = min(new_a[i], clip) + new_a = np.reshape(new_a, a.shape) + return new_a + + if cell_clip > 0.0: + c = array_clip(c, cell_clip) if w_c is None: g_o = act_gate(g_o) # 1 x D else: @@ -64,6 +77,8 @@ def lstmp( # projection r = np.dot(h, w_rh) r = act_proj(r) + if proj_clip > 0.0: + r = array_clip(r, proj_clip) return r, c def _reverse(x, offset): @@ -87,13 +102,13 @@ def lstmp( # compute one sequence seq_len = lod[0][i] x = input[offset[i]:offset[i + 1], :] - r_pre = np.dot(h0[i], w_rh) # 1 x P - r_pre = act_proj(r_pre) + r_pre = h0[i] c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate, - act_cell, act_cand, act_proj) + r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, proj_clip, + cell_clip, act_gate, act_cell, act_cand, + act_proj) projection.append(r_pre.flatten()) cell.append(c_pre.flatten()) @@ -123,13 +138,12 @@ class TestLstmpOp(LstmTest.TestLstmOp): T = sum(self.lod[0]) N = len(self.lod[0]) - x = np.random.normal(size=(T, 4 * self.D)).astype('float64') if self.has_initial_state: - h0 = np.random.normal(size=(N, self.D)).astype('float64') + h0 = np.random.normal(size=(N, self.P)).astype('float64') c0 = np.random.normal(size=(N, self.D)).astype('float64') else: - h0 = np.zeros((N, self.D)).astype('float64') + h0 = np.zeros((N, self.P)).astype('float64') c0 = np.zeros((N, self.D)).astype('float64') w = np.random.normal(size=(self.P, 4 * self.D)).astype('float64') if self.use_peepholes: @@ -140,9 +154,12 @@ class TestLstmpOp(LstmTest.TestLstmOp): w_b = b[:, 0:4 * self.D] w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') + proj_clip = 0.1 + cell_clip = 0.1 r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, - ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], - ACTIVATION[self.act_cand], ACTIVATION[self.act_proj]) + proj_clip, cell_clip, ACTIVATION[self.act_gate], + ACTIVATION[self.act_cell], ACTIVATION[self.act_cand], + ACTIVATION[self.act_proj]) self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} @@ -159,6 +176,8 @@ class TestLstmpOp(LstmTest.TestLstmOp): self.attrs = { 'use_peepholes': self.use_peepholes, 'is_reverse': self.is_reverse, + 'proj_clip': proj_clip, + 'cell_clip': cell_clip, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, 'candidate_activation': self.act_cand, @@ -171,14 +190,14 @@ class TestLstmpOp(LstmTest.TestLstmOp): def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. 
N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'], - max_relative_error=1e-2) + max_relative_error=1e-2, + numeric_grad_delta=0.0000005) class TestLstmpOpHasInitial(TestLstmpOp): @@ -188,7 +207,6 @@ class TestLstmpOpHasInitial(TestLstmpOp): def test_check_grad(self): # TODO(qingqing) remove folowing lines after the check_grad is refined. N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -196,11 +214,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'], ['Projection'], + numeric_grad_delta=0.0000005, max_relative_error=1e-2) def test_check_grad_ingore_bias(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -208,11 +226,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'ProjWeight', 'Weight'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('Bias')) def test_check_grad_ingore_weight(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -220,11 +238,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'ProjWeight', 'Bias'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('Weight')) def test_check_grad_ingore_proj_weight(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -232,11 +250,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'Bias'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('ProjWeight')) def test_check_grad_ingore_input(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -244,11 +262,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Weight', 'ProjWeight', 'Bias'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('Input')) def test_check_grad_ingore_h0(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, 
self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -256,11 +274,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('H0')) def test_check_grad_ingore_c0(self): N = len(self.lod[0]) - self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchCellPreAct'] = np.zeros( @@ -268,6 +286,7 @@ class TestLstmpOpHasInitial(TestLstmpOp): self.check_grad( ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'], max_relative_error=1e-2, + numeric_grad_delta=0.0000005, no_grad_set=set('C0'))
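The gradient checks above exercise the new `cell_clip`/`proj_clip` attributes at the operator level. For completeness, a minimal sketch of how the same options would be enabled through the Python API; the sizes and variable names are illustrative, not taken from the tests:

```python
import paddle.fluid as fluid

dict_dim, emb_dim, hidden_dim, proj_dim = 128, 64, 512, 256
words = fluid.layers.data(name='words', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(input=words, size=[dict_dim, emb_dim])
# dynamic_lstmp expects its input to have width 4 * hidden_dim
fc_out = fluid.layers.fc(input=emb, size=hidden_dim * 4)
proj_out, cell = fluid.layers.dynamic_lstmp(
    input=fc_out,
    size=hidden_dim * 4,
    proj_size=proj_dim,
    cell_clip=3.0,   # clip the cell state before the output activation
    proj_clip=3.0)   # clip the projected output elementwise to [-3.0, 3.0]
```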