diff --git a/README.md b/README.md index a67cb8ad439f462c361cb6bac2449c3a4b042126..60ffbe728178705b1734e682868614025214c2a4 100644 --- a/README.md +++ b/README.md @@ -76,33 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85 ## Installation -It is recommended to check out the -[Docker installation guide](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/docker_install_en.html) -before looking into the -[build from source guide](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/build_from_source_en.html). +It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/install/install_doc.html) on our website. ## Documentation -We provide [English](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) and -[Chinese](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) documentation. +We provide [English](http://paddlepaddle.org/documentation/docs/en/0.14.0/getstarted/index_en.html) and +[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/index.html) documentation. -- [Deep Learning 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html) +- [Deep Learning 101](https://github.com/PaddlePaddle/book) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/cluster/index_en.html) +- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/user_guides/howto/training/cluster_howto.html) You can run distributed training jobs on MPI clusters. -- [Distributed Training on Kubernetes](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/cluster/multi_cluster/k8s_en.html) - - You can also run distributed training jobs on Kubernetes clusters. - -- [Python API](http://www.paddlepaddle.org/docs/develop/api/en/overview.html) +- [Python API](http://paddlepaddle.org/documentation/api/zh/0.14.0/fluid.html) Our new API enables much shorter programs. -- [How to Contribute](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/dev/contribute_to_paddle_en.html) +- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/advanced_usage/development/contribute_to_paddle.html) We appreciate your contributions! 
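As a quick check of the install flow described in the README above, a minimal Fluid program can be run end to end (a sketch only, assuming the 0.14.0 Python API referenced in those docs; the variable names are illustrative):

```python
# Minimal smoke test for a fresh PaddlePaddle install (sketch; assumes the 0.14.0 Fluid API).
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.fc(input=x, size=2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Feed a random batch of 3 rows; expect an output of shape (3, 2).
out, = exe.run(feed={'x': np.random.rand(3, 4).astype('float32')},
               fetch_list=[y])
print(out.shape)
```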
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index b6ae930b7155d15d24b287cc3eed50f2aeaa5599..bb5f2894c08b5d8941ad8914f6b83280aa053e37 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -172,6 +172,7 @@ paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None)) paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)) +paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h index da163835e8652ae479121bd67f2eed77332b2740..a068d3543d9d2abec203f86362a8be5ba135d04d 100644 --- a/paddle/fluid/framework/rw_lock.h +++ b/paddle/fluid/framework/rw_lock.h @@ -56,76 +56,5 @@ struct RWLock { }; #endif -class RWLockGuard { - public: - enum Status { kUnLock, kWRLock, kRDLock }; - - RWLockGuard(RWLock* rw_lock, Status init_status) - : lock_(rw_lock), status_(Status::kUnLock) { - switch (init_status) { - case Status::kRDLock: { - RDLock(); - break; - } - case Status::kWRLock: { - WRLock(); - break; - } - case Status::kUnLock: { - break; - } - } - } - - void WRLock() { - switch (status_) { - case Status::kUnLock: { - lock_->WRLock(); - status_ = Status::kWRLock; - break; - } - case Status::kWRLock: { - break; - } - case Status::kRDLock: { - PADDLE_THROW( - "Please unlock read lock first before invoking write lock."); - break; - } - } - } - - void RDLock() { - switch (status_) { - case Status::kUnLock: { - lock_->RDLock(); - status_ = Status::kRDLock; - break; - } - case Status::kRDLock: { - break; - } - case Status::kWRLock: { - PADDLE_THROW( - "Please unlock write lock first before invoking read lock."); - break; - } - } - } - - void UnLock() { - if (status_ != Status::kUnLock) { - lock_->UNLock(); - status_ = Status::kUnLock; - } - } - - ~RWLockGuard() { UnLock(); } - - private: - RWLock* lock_; - Status status_; -}; - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 4a7a6bcf7154d5680de751e3c933be46fb09fd74..22cbf680c0670552fb014043c69fcadc56863529 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -118,6 +118,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { output_channels / groups * output_height * output_width * output_depth; int group_offset_filter = filter->numel() / groups; // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; size_t 
workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; if (user_workspace_size > 0) { @@ -158,18 +159,20 @@ class CUDNNConvOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, "workspace_size to be allocated exceeds the limit"); + // Allocate on GPU memory + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, - cudnn_filter_desc, filter_data + i * group_offset_filter, - cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, - &beta, cudnn_output_desc, output_data + i * group_offset_out)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_filter_desc, filter_data + i * group_offset_filter, + cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, + &beta, cudnn_output_desc, output_data + i * group_offset_out)); } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -311,7 +314,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { cudnn_filter_desc, filter_algo, &tmp_size)); workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); } - + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { @@ -319,15 +326,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset input_grad. for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, - filter_data + i * group_offset_filter, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, - data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_input_desc, input_grad_data + i * group_offset_in)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data + i * group_offset_in)); } } // ------------------- cudnn conv backward filter --------------------- @@ -335,17 +339,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset filter_grad. 
for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_input_desc, - input_data + i * group_offset_in, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, - filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_output_grad_desc, output_grad_data + i * group_offset_out, + cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_desc, + filter_grad_data + i * group_offset_filter)); } } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 73831611d01b8c5b8d2d9f7f15634a0094e4a608..82fff68e7557b3f0b44e6faf2a50e5a0ecbba589 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -76,6 +76,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { conv_desc.descriptor(paddings, strides, dilations); // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; size_t workspace_size_in_bytes; // final workspace to allocate. size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes; if (user_workspace_size > 0) { @@ -99,21 +100,25 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, cudnn_output_desc, algo, &workspace_size_in_bytes)); + // Allocate on GPU memory + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv transpose forward --------------------- int input_offset = input->numel() / input->dims()[0] / groups; int output_offset = output->numel() / output->dims()[0] / groups; int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, - cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, - algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_output_desc, output_data + output_offset * g)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, + cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, + algo, cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_output_desc, output_data + output_offset * g)); } + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; @@ -201,6 +206,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { std::max(workspace_size_in_bytes, bwd_filter_ws_size); } + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = 
paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- // FIXME(typhoonzero): template type T may not be the same as cudnn call. int input_offset = input->numel() / input->dims()[0] / groups; @@ -212,15 +222,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_output_desc, - output_grad_data + output_grad_offset * g, cudnn_filter_desc, - filter_data + filter_offset * g, cudnn_conv_desc, data_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, - input_grad_data + input_offset * g)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_filter_desc, + filter_data + filter_offset * g, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data + input_offset * g)); } } @@ -230,17 +237,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset filter_grad. // Gradient with respect to the filter for (int g = 0; g < groups; g++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_output_desc, - output_grad_data + output_grad_offset * g, cudnn_input_desc, - input_data + input_offset * g, cudnn_conv_desc, filter_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_filter_desc, filter_grad_data + filter_offset * g)); - }; - dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_input_desc, + input_data + input_offset * g, cudnn_conv_desc, filter_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, + filter_grad_data + filter_offset * g)); } } + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); } }; diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0dee1781623d5a62830545c0952e5aadbe37accb --- /dev/null +++ b/paddle/fluid/operators/detection/bbox_util.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +/* + * transform that computes target bounding-box regression deltas + * given proposal boxes and ground-truth boxes. + */ +template +inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes, + const framework::Tensor& gt_boxes, const T* weights, + const bool normalized, framework::Tensor* box_delta) { + auto ex_boxes_et = framework::EigenTensor::From(ex_boxes); + auto gt_boxes_et = framework::EigenTensor::From(gt_boxes); + auto trg = framework::EigenTensor::From(*box_delta); + T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y; + for (int64_t i = 0; i < box_num; ++i) { + ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + (normalized == false); + ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + (normalized == false); + ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w; + ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h; + + gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + (normalized == false); + gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + (normalized == false); + gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w; + gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h; + + trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w; + trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h; + trg(i, 2) = std::log(gt_w / ex_w); + trg(i, 3) = std::log(gt_h / ex_h); + + if (weights) { + trg(i, 0) = trg(i, 0) / weights[0]; + trg(i, 1) = trg(i, 1) / weights[1]; + trg(i, 2) = trg(i, 2) / weights[2]; + trg(i, 3) = trg(i, 3) / weights[3]; + } + } +} + +template +void Gather(const T* in, const int in_stride, const int* index, const int num, + T* out) { + const int stride_bytes = in_stride * sizeof(T); + for (int i = 0; i < num; ++i) { + int id = index[i]; + memcpy(out + i * in_stride, in + id * in_stride, stride_bytes); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index 0571c46f6be99c9a06b7dd2abb310eeda506ecd5..be06dc19743cfa6f093bcb3f4e9f91af315d4211 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/math_function.h" @@ -133,31 +134,6 @@ void BboxOverlaps(const Tensor& r_boxes, const Tensor& c_boxes, } } -template -void BoxToDelta(int box_num, const Tensor& ex_boxes, const Tensor& gt_boxes, - const std::vector& weights, Tensor* box_delta) { - auto ex_boxes_et = framework::EigenTensor::From(ex_boxes); - auto gt_boxes_et = framework::EigenTensor::From(gt_boxes); - auto box_delta_et = framework::EigenTensor::From(*box_delta); - T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y; - for (int64_t i = 0; i < box_num; ++i) { - ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + 1; - ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + 1; - ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w; - ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h; - - gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + 1; - gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + 1; - gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w; - gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h; - - box_delta_et(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]; - box_delta_et(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]; - box_delta_et(i, 2) = log(gt_w / ex_w) / ex_w / weights[2]; - box_delta_et(i, 3) = log(gt_h / ex_h) / ex_h / weights[3]; - } -} - template std::vector> SampleFgBgGt( const platform::CPUDeviceContext& context, Tensor* iou, @@ -243,12 +219,11 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context, Tensor* sampled_labels, Tensor* sampled_gts) { int fg_num = fg_inds.size(); int bg_num = bg_inds.size(); - int gt_num = fg_num + bg_num; Tensor fg_inds_t, bg_inds_t, gt_box_inds_t, gt_label_inds_t; int* fg_inds_data = fg_inds_t.mutable_data({fg_num}, context.GetPlace()); int* bg_inds_data = bg_inds_t.mutable_data({bg_num}, context.GetPlace()); int* gt_box_inds_data = - gt_box_inds_t.mutable_data({gt_num}, context.GetPlace()); + gt_box_inds_t.mutable_data({fg_num}, context.GetPlace()); int* gt_label_inds_data = gt_label_inds_t.mutable_data({fg_num}, context.GetPlace()); std::copy(fg_inds.begin(), fg_inds.end(), fg_inds_data); @@ -303,18 +278,20 @@ std::vector SampleRoisForOneImage( // Gather boxes and labels Tensor sampled_boxes, sampled_labels, sampled_gts; - int boxes_num = fg_inds.size() + bg_inds.size(); + int fg_num = fg_inds.size(); + int bg_num = bg_inds.size(); + int boxes_num = fg_num + bg_num; framework::DDim bbox_dim({boxes_num, kBoxDim}); sampled_boxes.mutable_data(bbox_dim, context.GetPlace()); sampled_labels.mutable_data({boxes_num}, context.GetPlace()); - sampled_gts.mutable_data(bbox_dim, context.GetPlace()); + sampled_gts.mutable_data({fg_num, kBoxDim}, context.GetPlace()); GatherBoxesLabels(context, boxes, *gt_boxes, *gt_classes, fg_inds, bg_inds, gt_inds, &sampled_boxes, &sampled_labels, &sampled_gts); // Compute targets Tensor bbox_targets_single; bbox_targets_single.mutable_data(bbox_dim, context.GetPlace()); - BoxToDelta(boxes_num, sampled_boxes, sampled_gts, bbox_reg_weights, + BoxToDelta(fg_num, sampled_boxes, sampled_gts, nullptr, false, &bbox_targets_single); // Scale rois @@ -427,7 +404,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel { auto rpn_rois_lod = rpn_rois->lod().back(); auto gt_classes_lod = gt_classes->lod().back(); auto gt_boxes_lod = gt_boxes->lod().back(); - for (size_t i = 0; i < n; ++i) { + for (int i = 0; i < n; ++i) { Tensor 
rpn_rois_slice = rpn_rois->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]); Tensor gt_classes_slice = diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index fcdcafae7273afa6887ee531dfc37ef833b92d68..ebe6830eccd87a156768eb0d4b96220bcc9f4edc 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -311,8 +311,7 @@ class GenerateProposalsKernel : public framework::OpKernel { rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, context.GetPlace()); - rpn_roi_probs->mutable_data({scores->numel() / 4, 1}, - context.GetPlace()); + rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); Tensor bbox_deltas_swap, scores_swap; bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, @@ -421,7 +420,7 @@ class GenerateProposalsKernel : public framework::OpKernel { CPUGather(ctx, proposals, keep, &bbox_sel); CPUGather(ctx, scores_sel, keep, &scores_filter); if (nms_thresh <= 0) { - return std::make_pair(bbox_sel, scores_sel); + return std::make_pair(bbox_sel, scores_filter); } Tensor keep_nms = NMS(ctx, &bbox_sel, &scores_filter, nms_thresh, eta); diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 177ff7cf187bc9daf69889e99ca57ae18766de90..88757f25cd9a5789758640de2d9cae0b12350b25 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -46,156 +47,219 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { auto in_dims = ctx->GetInputDim("DistMat"); PADDLE_ENFORCE_EQ(in_dims.size(), 2, "The rank of Input(DistMat) must be 2."); + + ctx->SetOutputDim("LocationIndex", {-1}); + ctx->SetOutputDim("ScoreIndex", {-1}); + ctx->SetOutputDim("TargetLabel", {-1, 1}); + ctx->SetOutputDim("TargetBBox", {-1, 4}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("DistMat")->type()), + platform::CPUPlace()); } }; template class RpnTargetAssignKernel : public framework::OpKernel { public: + void Compute(const framework::ExecutionContext& context) const override { + auto* anchor_t = context.Input("Anchor"); // (H*W*A) * 4 + auto* gt_bbox_t = context.Input("GtBox"); + auto* dist_t = context.Input("DistMat"); + + auto* loc_index_t = context.Output("LocationIndex"); + auto* score_index_t = context.Output("ScoreIndex"); + auto* tgt_bbox_t = context.Output("TargetBBox"); + auto* tgt_lbl_t = context.Output("TargetLabel"); + + auto lod = dist_t->lod().back(); + int64_t batch_num = static_cast(lod.size() - 1); + int64_t anchor_num = dist_t->dims()[1]; + PADDLE_ENFORCE_EQ(anchor_num, anchor_t->dims()[0]); + + int rpn_batch_size = context.Attr("rpn_batch_size_per_im"); + float pos_threshold = context.Attr("rpn_positive_overlap"); + float neg_threshold = context.Attr("rpn_negative_overlap"); + float fg_fraction = context.Attr("fg_fraction"); + + int fg_num_per_batch = static_cast(rpn_batch_size * fg_fraction); + + int64_t max_num = batch_num * anchor_num; + auto place = context.GetPlace(); + + 
tgt_bbox_t->mutable_data({max_num, 4}, place); + auto* loc_index = loc_index_t->mutable_data({max_num}, place); + auto* score_index = score_index_t->mutable_data({max_num}, place); + + Tensor tmp_tgt_lbl; + auto* tmp_lbl_data = tmp_tgt_lbl.mutable_data({max_num}, place); + auto& dev_ctx = context.device_context(); + math::SetConstant iset; + iset(dev_ctx, &tmp_tgt_lbl, static_cast(-1)); + + std::random_device rnd; + std::minstd_rand engine; + int seed = + context.Attr("fix_seed") ? context.Attr("seed") : rnd(); + engine.seed(seed); + + int fg_num = 0; + int bg_num = 0; + for (int i = 0; i < batch_num; ++i) { + Tensor dist = dist_t->Slice(lod[i], lod[i + 1]); + Tensor gt_bbox = gt_bbox_t->Slice(lod[i], lod[i + 1]); + auto fg_bg_gt = SampleFgBgGt(dev_ctx, dist, pos_threshold, neg_threshold, + rpn_batch_size, fg_num_per_batch, engine, + tmp_lbl_data + i * anchor_num); + + int cur_fg_num = fg_bg_gt[0].size(); + int cur_bg_num = fg_bg_gt[1].size(); + std::transform(fg_bg_gt[0].begin(), fg_bg_gt[0].end(), loc_index, + [i, anchor_num](int d) { return d + i * anchor_num; }); + memcpy(score_index, loc_index, cur_fg_num * sizeof(int)); + std::transform(fg_bg_gt[1].begin(), fg_bg_gt[1].end(), + score_index + cur_fg_num, + [i, anchor_num](int d) { return d + i * anchor_num; }); + + // get target bbox deltas + if (cur_fg_num) { + Tensor fg_gt; + T* gt_data = fg_gt.mutable_data({cur_fg_num, 4}, place); + Tensor tgt_bbox = tgt_bbox_t->Slice(fg_num, fg_num + cur_fg_num); + T* tgt_data = tgt_bbox.data(); + Gather(anchor_t->data(), 4, + reinterpret_cast(&fg_bg_gt[0][0]), cur_fg_num, + tgt_data); + Gather(gt_bbox.data(), 4, reinterpret_cast(&fg_bg_gt[2][0]), + cur_fg_num, gt_data); + BoxToDelta(cur_fg_num, tgt_bbox, fg_gt, nullptr, false, &tgt_bbox); + } + + loc_index += cur_fg_num; + score_index += cur_fg_num + cur_bg_num; + fg_num += cur_fg_num; + bg_num += cur_bg_num; + } + + int lbl_num = fg_num + bg_num; + PADDLE_ENFORCE_LE(fg_num, max_num); + PADDLE_ENFORCE_LE(lbl_num, max_num); + + tgt_bbox_t->Resize({fg_num, 4}); + loc_index_t->Resize({fg_num}); + score_index_t->Resize({lbl_num}); + auto* lbl_data = tgt_lbl_t->mutable_data({lbl_num, 1}, place); + Gather(tmp_lbl_data, 1, score_index_t->data(), lbl_num, + lbl_data); + } + + private: void ScoreAssign(const T* dist_data, const Tensor& anchor_to_gt_max, const int row, const int col, const float pos_threshold, - const float neg_threshold, int64_t* target_label_data, + const float neg_threshold, int64_t* target_label, std::vector* fg_inds, std::vector* bg_inds) const { - int fg_offset = fg_inds->size(); - int bg_offset = bg_inds->size(); + float epsilon = 0.0001; for (int64_t i = 0; i < row; ++i) { const T* v = dist_data + i * col; - T max_dist = *std::max_element(v, v + col); + T max = *std::max_element(v, v + col); for (int64_t j = 0; j < col; ++j) { - T val = dist_data[i * col + j]; - if (val == max_dist) target_label_data[j] = 1; + if (std::abs(max - v[j]) < epsilon) { + target_label[j] = 1; + } } } - // Pick the fg/bg and count the number + // Pick the fg/bg + const T* anchor_to_gt_max_data = anchor_to_gt_max.data(); for (int64_t j = 0; j < col; ++j) { - if (anchor_to_gt_max.data()[j] > pos_threshold) { - target_label_data[j] = 1; - } else if (anchor_to_gt_max.data()[j] < neg_threshold) { - target_label_data[j] = 0; + if (anchor_to_gt_max_data[j] >= pos_threshold) { + target_label[j] = 1; + } else if (anchor_to_gt_max_data[j] < neg_threshold) { + target_label[j] = 0; } - if (target_label_data[j] == 1) { - fg_inds->push_back(fg_offset + j); - } else if 
(target_label_data[j] == 0) { - bg_inds->push_back(bg_offset + j); + if (target_label[j] == 1) { + fg_inds->push_back(j); + } else if (target_label[j] == 0) { + bg_inds->push_back(j); } } } - void ReservoirSampling(const int num, const int offset, - std::minstd_rand engine, + void ReservoirSampling(const int num, std::minstd_rand engine, std::vector* inds) const { std::uniform_real_distribution uniform(0, 1); - const int64_t size = static_cast(inds->size() - offset); - if (size > num) { - for (int64_t i = num; i < size; ++i) { + size_t len = inds->size(); + if (len > static_cast(num)) { + for (size_t i = num; i < len; ++i) { int rng_ind = std::floor(uniform(engine) * i); if (rng_ind < num) - std::iter_swap(inds->begin() + rng_ind + offset, - inds->begin() + i + offset); + std::iter_swap(inds->begin() + rng_ind, inds->begin() + i); } + inds->resize(num); } } - void RpnTargetAssign(const framework::ExecutionContext& ctx, - const Tensor& dist, const float pos_threshold, - const float neg_threshold, const int rpn_batch_size, - const int fg_num, std::minstd_rand engine, - std::vector* fg_inds, std::vector* bg_inds, - int64_t* target_label_data) const { + // std::vector> RpnTargetAssign( + std::vector> SampleFgBgGt( + const platform::CPUDeviceContext& ctx, const Tensor& dist, + const float pos_threshold, const float neg_threshold, + const int rpn_batch_size, const int fg_num, std::minstd_rand engine, + int64_t* target_label) const { auto* dist_data = dist.data(); - int64_t row = dist.dims()[0]; - int64_t col = dist.dims()[1]; - int fg_offset = fg_inds->size(); - int bg_offset = bg_inds->size(); + int row = dist.dims()[0]; + int col = dist.dims()[1]; + + std::vector fg_inds; + std::vector bg_inds; + std::vector gt_inds; // Calculate the max IoU between anchors and gt boxes - Tensor anchor_to_gt_max; - anchor_to_gt_max.mutable_data( - framework::make_ddim({static_cast(col), 1}), - platform::CPUPlace()); - auto& place = *ctx.template device_context() - .eigen_device(); - auto x = EigenMatrix::From(dist); - auto x_col_max = EigenMatrix::From(anchor_to_gt_max); - x_col_max.device(place) = - x.maximum(Eigen::DSizes(0)) - .reshape(Eigen::DSizes(static_cast(col), 1)); + // Map from anchor to gt box that has highest overlap + auto place = ctx.GetPlace(); + Tensor anchor_to_gt_max, anchor_to_gt_argmax; + anchor_to_gt_max.mutable_data({col}, place); + int* argmax = anchor_to_gt_argmax.mutable_data({col}, place); + + auto x = framework::EigenMatrix::From(dist); + auto x_col_max = framework::EigenVector::Flatten(anchor_to_gt_max); + auto x_col_argmax = + framework::EigenVector::Flatten(anchor_to_gt_argmax); + x_col_max = x.maximum(Eigen::DSizes(0)); + x_col_argmax = x.argmax(0).template cast(); + // Follow the Faster RCNN's implementation ScoreAssign(dist_data, anchor_to_gt_max, row, col, pos_threshold, - neg_threshold, target_label_data, fg_inds, bg_inds); + neg_threshold, target_label, &fg_inds, &bg_inds); // Reservoir Sampling - ReservoirSampling(fg_num, fg_offset, engine, fg_inds); - int bg_num = rpn_batch_size - (fg_inds->size() - fg_offset); - ReservoirSampling(bg_num, bg_offset, engine, bg_inds); - } + ReservoirSampling(fg_num, engine, &fg_inds); + int fg_num2 = static_cast(fg_inds.size()); + int bg_num = rpn_batch_size - fg_num2; + ReservoirSampling(bg_num, engine, &bg_inds); - void Compute(const framework::ExecutionContext& context) const override { - auto* dist = context.Input("DistMat"); - auto* loc_index = context.Output("LocationIndex"); - auto* score_index = context.Output("ScoreIndex"); - 
auto* tgt_lbl = context.Output("TargetLabel"); - - auto col = dist->dims()[1]; - int64_t n = dist->lod().size() == 0UL - ? 1 - : static_cast(dist->lod().back().size() - 1); - if (dist->lod().size()) { - PADDLE_ENFORCE_EQ(dist->lod().size(), 1UL, - "Only support 1 level of LoD."); + gt_inds.reserve(fg_num2); + for (int i = 0; i < fg_num2; ++i) { + gt_inds.emplace_back(argmax[fg_inds[i]]); } - int rpn_batch_size = context.Attr("rpn_batch_size_per_im"); - float pos_threshold = context.Attr("rpn_positive_overlap"); - float neg_threshold = context.Attr("rpn_negative_overlap"); - float fg_fraction = context.Attr("fg_fraction"); - - int fg_num = static_cast(rpn_batch_size * fg_fraction); - - int64_t* target_label_data = - tgt_lbl->mutable_data({n * col, 1}, context.GetPlace()); + std::vector> fg_bg_gt; + fg_bg_gt.emplace_back(fg_inds); + fg_bg_gt.emplace_back(bg_inds); + fg_bg_gt.emplace_back(gt_inds); - auto& dev_ctx = context.device_context(); - math::SetConstant iset; - iset(dev_ctx, tgt_lbl, static_cast(-1)); - - std::vector fg_inds; - std::vector bg_inds; - std::random_device rnd; - std::minstd_rand engine; - int seed = - context.Attr("fix_seed") ? context.Attr("seed") : rnd(); - engine.seed(seed); - - if (n == 1) { - RpnTargetAssign(context, *dist, pos_threshold, neg_threshold, - rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds, - target_label_data); - } else { - auto lod = dist->lod().back(); - for (size_t i = 0; i < lod.size() - 1; ++i) { - Tensor one_ins = dist->Slice(lod[i], lod[i + 1]); - RpnTargetAssign(context, one_ins, pos_threshold, neg_threshold, - rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds, - target_label_data + i * col); - } - } - int* loc_index_data = loc_index->mutable_data( - {static_cast(fg_inds.size())}, context.GetPlace()); - int* score_index_data = score_index->mutable_data( - {static_cast(fg_inds.size() + bg_inds.size())}, - context.GetPlace()); - memcpy(loc_index_data, reinterpret_cast(&fg_inds[0]), - fg_inds.size() * sizeof(int)); - memcpy(score_index_data, reinterpret_cast(&fg_inds[0]), - fg_inds.size() * sizeof(int)); - memcpy(score_index_data + fg_inds.size(), - reinterpret_cast(&bg_inds[0]), bg_inds.size() * sizeof(int)); + return fg_bg_gt; } }; class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { + AddInput("Anchor", + "(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4]."); + AddInput("GtBox", "(LoDTensor) input ground-truth bbox with shape [K, 4]."); AddInput( "DistMat", "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " @@ -241,12 +305,15 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { "ScoreIndex", "(Tensor), The indexes of foreground and background anchors in all " "RPN anchors (The rest anchors are ignored).
The shape of the " - "ScoreIndex is [F + B], F and B depend on the value of input " - "tensor and attributes."); - AddOutput("TargetLabel", - "(Tensor), The target labels of each anchor with shape " - "[K * M, 1], " - "K and M is the same as they are in DistMat."); + "ScoreIndex is [F + B], F and B are sampled foreground and backgroud " + " number."); + AddOutput("TargetBBox", + "(Tensor), The target bbox deltas with shape " + "[F, 4], F is the sampled foreground number."); + AddOutput( + "TargetLabel", + "(Tensor), The target labels of each anchor with shape " + "[F + B, 1], F and B are sampled foreground and backgroud number."); AddComment(R"DOC( This operator can be, for given the IoU between the ground truth bboxes and the anchors, to assign classification and regression targets to each prediction. diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu index 50450b62f7b1c0b2b5abf01a43581a0e2d2cd01e..46e20285db6d7acd39dead3994409645adddf494 100644 --- a/paddle/fluid/operators/roi_pool_op.cu +++ b/paddle/fluid/operators/roi_pool_op.cu @@ -31,7 +31,7 @@ static inline int NumBlocks(const int N) { template __global__ void GPUROIPoolForward( - const int nthreads, const T* input_data, const int64_t* input_rois, + const int nthreads, const T* input_data, const T* input_rois, const float spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, int* roi_batch_id_data, T* output_data, int64_t* argmax_data) { @@ -43,7 +43,7 @@ __global__ void GPUROIPoolForward( int c = (i / pooled_width / pooled_height) % channels; int n = i / pooled_width / pooled_height / channels; - const int64_t* offset_input_rois = input_rois + n * kROISize; + const T* offset_input_rois = input_rois + n * kROISize; int roi_batch_ind = roi_batch_id_data[n]; int roi_start_w = round(offset_input_rois[0] * spatial_scale); int roi_start_h = round(offset_input_rois[1] * spatial_scale); @@ -93,7 +93,7 @@ __global__ void GPUROIPoolForward( template __global__ void GPUROIPoolBackward( - const int nthreads, const int64_t* input_rois, const T* output_grad, + const int nthreads, const T* input_rois, const T* output_grad, const int64_t* argmax_data, const int num_rois, const float spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, int* roi_batch_id_data, @@ -174,8 +174,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel { GPUROIPoolForward< T><<>>( - output_size, in->data(), rois->data(), spatial_scale, - channels, height, width, pooled_height, pooled_width, + output_size, in->data(), rois->data(), spatial_scale, channels, + height, width, pooled_height, pooled_width, roi_batch_id_list_gpu.data(), out->mutable_data(ctx.GetPlace()), argmax->mutable_data(ctx.GetPlace())); } @@ -228,7 +228,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel { if (output_grad_size > 0) { GPUROIPoolBackward< T><<>>( - output_grad_size, rois->data(), out_grad->data(), + output_grad_size, rois->data(), out_grad->data(), argmax->data(), rois_num, spatial_scale, channels, height, width, pooled_height, pooled_width, roi_batch_id_list_gpu.data(), diff --git a/paddle/fluid/operators/roi_pool_op.h b/paddle/fluid/operators/roi_pool_op.h index c4f739b2c6b2d62ebebcc15fd627ebad040e7b3f..07de7c9f0e070cef7c6f38f8d564ab76910842db 100644 --- a/paddle/fluid/operators/roi_pool_op.h +++ b/paddle/fluid/operators/roi_pool_op.h @@ -72,7 +72,7 @@ class CPUROIPoolOpKernel : public framework::OpKernel { 
T* output_data = out->mutable_data(ctx.GetPlace()); int64_t* argmax_data = argmax->mutable_data(ctx.GetPlace()); - const int64_t* rois_data = rois->data(); + const T* rois_data = rois->data(); for (int n = 0; n < rois_num; ++n) { int roi_batch_id = roi_batch_id_data[n]; int roi_start_w = round(rois_data[0] * spatial_scale); @@ -171,7 +171,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { } } - const int64_t* rois_data = rois->data(); + const T* rois_data = rois->data(); const T* out_grad_data = out_grad->data(); const int64_t* argmax_data = argmax->data(); T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..58e48c228bb34814700fd0f7a3d62ef4b1a435dd --- /dev/null +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_enumerate_op.h" + +namespace paddle { +namespace operators { + +class SequenceEnumerateOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("X"), + "Input(X) of SequenceEnumerate operator should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of SequenceEnumerate operator should not be null."); + + const auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ( + x_dims.size(), 2UL, + "Input(X) of SequenceEnumerate operator's rank should be 2."); + PADDLE_ENFORCE_EQ( + x_dims[1], 1UL, + "Input(X) of SequenceEnumerate operator's 2nd dimension should be 1."); + + const auto win_size = ctx->Attrs().Get("win_size"); + ctx->SetOutputDim("Out", {x_dims[0], win_size}); + ctx->ShareLoD("X", "Out"); + } +}; + +class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(2-D LoDTensor with the 2nd dimension equal to 1) " + "Input LoDTensor of SequenceEnumerate operator."); + AddOutput("Out", + "(2-D LoDTensor with the 2nd dimension equal to win_size) " + "Output LoDTensor of SequenceEnumerate operator."); + AddAttr("win_size", "(int) The enumerate sequence window size.") + .AddCustomChecker([](const int& win_size) { + PADDLE_ENFORCE(win_size >= 2, + "The window size should not be less than 2."); + }); + AddAttr("pad_value", "(int) The enumerate sequence padding value.") + .SetDefault(0); + AddComment(R"DOC( +Sequence Enumerate Operator. + +Generate a new sequence for the input index sequence, which enumerates all the +sub-sequences with length `win_size` of the input. +The enumerated sequence has the same 1st dimension as variable `input`, and +the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
+ +Examples: +Case 1: + Input: + X.lod = [[0, 3, 5]] + X.data = [[1], [2], [3], [4], [5]] + X.dims = [5, 1] + Attrs: + win_size = 2 + pad_value = 0 + Output: + Out.lod = [[0, 3, 5]] + Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]] + Out.dims = [5, 2] + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp, + ops::SequenceEnumerateOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_enumerate, + ops::SequenceEnumerateKernel, + ops::SequenceEnumerateKernel); diff --git a/paddle/fluid/operators/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_enumerate_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bdc9a615aa9a1ecd99c1f6995361f8c5ff0aa383 --- /dev/null +++ b/paddle/fluid/operators/sequence_enumerate_op.cu @@ -0,0 +1,84 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/operators/sequence_enumerate_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template +__global__ void CalcOutPut(const T* in_data, const size_t* in_lod, + const size_t lod_len, const int64_t win_size, + const int64_t pad_value, T* out_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_lod[lod_len - 1]) { + int end_idx = 0; + // Get LoD interval of index + for (int i = 1; i < lod_len; ++i) { + if (index < in_lod[i]) { + end_idx = in_lod[i]; + break; + } + } + for (size_t i = 0; i < win_size; ++i) { + int word_pos = index + i; + out_data[index * win_size + i] = + word_pos < end_idx ? 
in_data[word_pos] : pad_value; + } + } +} + +template +class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int win_size = context.Attr("win_size"); + int pad_value = context.Attr("pad_value"); + + auto in_dims = in->dims(); + auto in_lod = in->lod(); + + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + /* Generate enumerate sequence set */ + auto stream = context.cuda_device_context().stream(); + auto lod0 = in_lod[0]; + auto in_len = in->numel(); + auto in_data = in->data(); + auto out_data = out->mutable_data(context.GetPlace()); + // Copy LoD to GPU + const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace()); + // Calc output tensor + CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL( + sequence_enumerate, + paddle::operators::SequenceEnumerateOpCUDAKernel, + paddle::operators::SequenceEnumerateOpCUDAKernel); diff --git a/paddle/fluid/operators/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_enumerate_op.h new file mode 100644 index 0000000000000000000000000000000000000000..dc18d9b2071303377505155476b87ed029eaf986 --- /dev/null +++ b/paddle/fluid/operators/sequence_enumerate_op.h @@ -0,0 +1,56 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +using LoDTensor = framework::LoDTensor; + +template +class SequenceEnumerateKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int win_size = context.Attr("win_size"); + int pad_value = context.Attr("pad_value"); + + auto in_dims = in->dims(); + auto in_lod = in->lod(); + + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + // Generate enumerate sequence set + auto lod0 = in_lod[0]; + auto in_data = in->data(); + auto out_data = out->mutable_data(context.GetPlace()); + for (size_t i = 0; i < lod0.size() - 1; ++i) { + for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) { + for (int word_idx = 0; word_idx < win_size; ++word_idx) { + size_t word_pos = idx + word_idx; + out_data[win_size * idx + word_idx] = + word_pos < lod0[i + 1] ? 
in_data[word_pos] : pad_value; + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 3ec20ad7e5b4d8104e7228da9f8e7f5409417301..2cc26da013f59f5b7ee1747d57baca9c1c0efe2c 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -16,9 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/memory/memory.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/rw_lock.h" -#endif namespace paddle { namespace platform { @@ -145,59 +142,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -class CudnnHolder { - public: - CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) - : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { - PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); - PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); - } - - cudnnHandle_t cudnn_handle() const { return cudnn_handle_; } - - void RunFunc(const std::function& cudnn_func, - size_t required_workspace_len) { - std::lock_guard lock(mtx_); - if (required_workspace_len > workspace_len_) { - ReallocateWorkspace(required_workspace_len); - } - cudnn_func(workspace_); - } - - ~CudnnHolder() { - PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); - if (workspace_ != nullptr) { - paddle::memory::Free(place_, workspace_); - } - } - - private: - void ReallocateWorkspace(size_t required_workspace_len) { - if (required_workspace_len <= workspace_len_) { - return; - } - void* new_workspace = paddle::memory::Alloc(place_, required_workspace_len); - if (workspace_ != nullptr) { - // Maybe someone is using the current workspace - PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); - paddle::memory::Free(place_, workspace_); - } - workspace_ = new_workspace; - workspace_len_ = required_workspace_len; - } - - cudnnHandle_t cudnn_handle_; - void* workspace_; - size_t workspace_len_; - - const cudaStream_t* stream_; // not owned; - const CUDAPlace place_; - - std::mutex mtx_; -}; - -CUDADeviceContext::CUDADeviceContext(CUDAPlace place) - : place_(place), cudnn_holder_(nullptr) { +CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { SetDeviceId(place_.device); compute_capability = GetCUDAComputeCapability(place_.device); multi_process = GetCUDAMultiProcessors(place_.device); @@ -209,7 +154,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); if (dynload::HasCUDNN()) { - cudnn_holder_.reset(new CudnnHolder(&stream_, place)); + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); + } else { + cudnn_handle_ = nullptr; } } @@ -217,6 +165,9 @@ CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); + if (cudnn_handle_ != nullptr) { + PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); + } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -245,14 +196,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { return cublas_handle_; } -cudnnHandle_t CUDADeviceContext::cudnn_handle() const { - return cudnn_holder_->cudnn_handle(); -} - -void CUDADeviceContext::RunCudnnFuncWithWorkspace( - const std::function& cudnn_func, size_t 
workspace_len) const { - cudnn_holder_->RunFunc(cudnn_func, workspace_len); -} +cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3ed49fc4233d4c0cd6cc16319eda08480ab9b434..b97dad20db0b003b4886b7c7cfd1c8de8bf44ab9 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -69,7 +69,6 @@ struct DefaultDeviceContextType { #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; -class CudnnHolder; class CUDADeviceContext : public DeviceContext { public: @@ -97,11 +96,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; - /*! \brief Run a cudnn function with the workspace provided by - * CUDADeviceContext */ - void RunCudnnFuncWithWorkspace(const std::function& cudnn_func, - size_t workspace_len) const; - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; @@ -117,8 +111,8 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - std::unique_ptr cudnn_holder_; cudaStream_t stream_; + cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; int compute_capability; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 5757b2798e43dc70b406462a74b4f74eedcf56fa..1bc1dbbecaccd328d84cd3364a50c8f828d823c0 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -145,26 +145,23 @@ def rpn_target_assign(loc, """ helper = LayerHelper('rpn_target_assign', **locals()) - # 1. Compute the regression target bboxes - target_bbox = box_coder( - prior_box=anchor_box, - prior_box_var=anchor_var, - target_box=gt_box, - code_type='encode_center_size', - box_normalized=False) - # 2. Compute overlaps between the prior boxes and the gt boxes overlaps + # Compute overlaps between the prior boxes and the gt boxes iou = iou_similarity(x=gt_box, y=anchor_box) - # 3. Assign target label to anchors - loc_index = helper.create_tmp_variable(dtype=anchor_box.dtype) - score_index = helper.create_tmp_variable(dtype=anchor_box.dtype) - target_label = helper.create_tmp_variable(dtype=anchor_box.dtype) + # Assign target label to anchors + loc_index = helper.create_tmp_variable(dtype='int32') + score_index = helper.create_tmp_variable(dtype='int32') + target_label = helper.create_tmp_variable(dtype='int64') + target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype) helper.append_op( type="rpn_target_assign", - inputs={'DistMat': iou}, + inputs={'Anchor': anchor_box, + 'GtBox': gt_box, + 'DistMat': iou}, outputs={ 'LocationIndex': loc_index, 'ScoreIndex': score_index, - 'TargetLabel': target_label + 'TargetLabel': target_label, + 'TargetBBox': target_bbox, }, attrs={ 'rpn_batch_size_per_im': rpn_batch_size_per_im, @@ -173,16 +170,16 @@ def rpn_target_assign(loc, 'fg_fraction': fg_fraction }) - # 4.
Reshape and gather the target entry - scores = nn.reshape(x=scores, shape=(-1, 2)) - loc = nn.reshape(x=loc, shape=(-1, 4)) - target_label = nn.reshape(x=target_label, shape=(-1, 1)) - target_bbox = nn.reshape(x=target_bbox, shape=(-1, 4)) + loc_index.stop_gradient = True + score_index.stop_gradient = True + target_label.stop_gradient = True + target_bbox.stop_gradient = True + scores = nn.reshape(x=scores, shape=(-1, 1)) + loc = nn.reshape(x=loc, shape=(-1, 4)) predicted_scores = nn.gather(scores, score_index) predicted_location = nn.gather(loc, loc_index) - target_label = nn.gather(target_label, score_index) - target_bbox = nn.gather(target_bbox, loc_index) + return predicted_scores, predicted_location, target_label, target_bbox diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0ecfc958a3b89c85ef00574d630042d410c3fa0a..a0d92fd1462acb18cdb2463b51138c9ff33b08a8 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -111,6 +111,7 @@ __all__ = [ 'stack', 'pad2d', 'unstack', + 'sequence_enumerate', ] @@ -5823,6 +5824,51 @@ def flatten(x, axis=1, name=None): return out +def sequence_enumerate(input, win_size, pad_value=0, name=None): + """ + Generate a new sequence for the input index sequence, which enumerates all the + sub-sequences with length `win_size` of the input. + The enumerated sequence has the same 1st dimension as variable `input`, and + the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation. + + Examples: + Case 1: + Input: + X.lod = [[0, 3, 5]] + X.data = [[1], [2], [3], [4], [5]] + X.dims = [5, 1] + Attrs: + win_size = 2 + pad_value = 0 + Output: + Out.lod = [[0, 3, 5]] + Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]] + Out.dims = [5, 2] + + Args: + input (Variable): The input variable which is an index sequence. + win_size (int): The window size for enumerating all sub-sequences. + pad_value (int): The padding value, default 0. + + Returns: + Variable: The enumerated sequence variable which is a LoDTensor. + + Examples: + ..
code-block:: python + + x = fluid.layers.data(name='x', shape=[30, 1], dtype='int32', lod_level=1) + out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) + """ + helper = LayerHelper('sequence_enumerate', **locals()) + out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True) + helper.append_op( + type='sequence_enumerate', + inputs={'X': input}, + outputs={'Out': out}, + attrs={'win_size': win_size, + 'pad_value': pad_value}) + return out + + def sequence_mask(x, maxlen=None, dtype='int64', name=None): """ **SequenceMask Layer** @@ -5902,6 +5948,7 @@ def stack(x, axis=0): helper.append_op( type='stack', inputs={'X': x}, outputs={'Y': out}, attrs={'axis': axis}) + return out diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index ec0bf3ff8d64345111537780aaa5367ed0e1f8ff..e2564763d19d180f7c6933429dddf58c77be7bb8 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -281,7 +281,7 @@ class TestRpnTargetAssign(unittest.TestCase): gt_box = layers.data( name='gt_box', shape=[4], lod_level=1, dtype='float32') - predicted_scores, predicted_location, target_label, target_bbox = layers.rpn_target_assign( + pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( loc=loc, scores=scores, anchor_box=anchor_box, @@ -292,15 +292,13 @@ class TestRpnTargetAssign(unittest.TestCase): rpn_positive_overlap=0.7, rpn_negative_overlap=0.3) - self.assertIsNotNone(predicted_scores) - self.assertIsNotNone(predicted_location) - self.assertIsNotNone(target_label) - self.assertIsNotNone(target_bbox) - assert predicted_scores.shape[1] == 2 - assert predicted_location.shape[1] == 4 - assert predicted_location.shape[1] == target_bbox.shape[1] - - print(str(program)) + self.assertIsNotNone(pred_scores) + self.assertIsNotNone(pred_loc) + self.assertIsNotNone(tgt_lbl) + self.assertIsNotNone(tgt_bbox) + assert pred_scores.shape[1] == 1 + assert pred_loc.shape[1] == 4 + assert pred_loc.shape[1] == tgt_bbox.shape[1] class TestGenerateProposals(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py index 764f83b534c8a183dbf21511f0b05741c13c9528..36ebc8fb6ea9efdcd1807f5c8917ab1428b3381e 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py @@ -37,7 +37,7 @@ def fusion_gru( h0, wh, np.zeros( - (1, wh.shape[1]), dtype='float64'), + (1, wh.shape[1]), dtype='float32'), is_reverse, act_state, act_gate) @@ -62,15 +62,15 @@ class TestFusionGRUOp(OpTest): T = sum(self.lod[0]) N = len(self.lod[0]) - x = np.random.rand(T, self.M).astype('float64') - wx = np.random.rand(self.M, 3 * self.D).astype('float64') - wh = np.random.rand(self.D, 3 * self.D).astype('float64') + x = np.random.rand(T, self.M).astype('float32') + wx = np.random.rand(self.M, 3 * self.D).astype('float32') + wh = np.random.rand(self.D, 3 * self.D).astype('float32') bias = np.random.rand( - 1, 3 * self.D).astype('float64') if self.with_bias else np.zeros( - (1, 3 * self.D), dtype='float64') + 1, 3 * self.D).astype('float32') if self.with_bias else np.zeros( + (1, 3 * self.D), dtype='float32') h0 = np.random.rand( - N, self.D).astype('float64') if self.with_h0 else np.zeros( - (N, self.D), dtype='float64') + N, self.D).astype('float32') if self.with_h0 else np.zeros( + (N, self.D), dtype='float32') _, _, _, hidden = fusion_gru( x, self.lod, h0, wx, wh, bias, 
self.is_reverse, @@ -93,7 +93,9 @@ class TestFusionGRUOp(OpTest): } def test_check_output(self): - self.check_output(atol=1e-8) + for use_seq in {True, False}: + self.attrs['use_seq'] = use_seq + self.check_output() class TestFusionGRUOpNoInitial(TestFusionGRUOp): diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 5805bdf461998e90611dec05b079cd55feda520d..1f1eb37667e304351a6a85edde09e7da32cf1630 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -114,7 +114,9 @@ class TestFusionLSTMOp(OpTest): } def test_check_output(self): - self.check_output() + for use_seq in {True, False}: + self.attrs['use_seq'] = use_seq + self.check_output() class TestFusionLSTMOpInit(TestFusionLSTMOp): diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py index ce766fffbce98a6a2cee4c508d6db85ee0163401..6dc101b6dad8813893c6a891da0e16f952bb4c2d 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels.py @@ -177,8 +177,8 @@ def _box_to_delta(ex_boxes, gt_boxes, weights): dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0] dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1] - dw = (np.log(gt_w / ex_w)) / ex_w / weights[2] - dh = (np.log(gt_h / ex_h)) / ex_h / weights[3] + dw = (np.log(gt_w / ex_w)) / weights[2] + dh = (np.log(gt_h / ex_h)) / weights[3] targets = np.vstack([dx, dy, dw, dh]).transpose() return targets diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index ecdf32524afb1357b192ce14674b7073972dee9f..bc4d364c74c6cb6b8f0df59e7ede77e6271f4b96 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -549,6 +549,13 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_sequence_enumerate(self): + program = Program() + with program_guard(program): + x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1) + out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py index ed7f467835f32242a9650f226b4a5ad9d6d87af4..ad4cd2e803bfae4c3fbc04503331b9a786b25d17 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_pool_op.py @@ -61,7 +61,7 @@ class TestROIPoolOp(OpTest): for i in range(self.rois_num): roi = self.rois[i] - roi_batch_id = roi[0] + roi_batch_id = int(roi[0]) roi_start_w = int(cpt.round(roi[1] * self.spatial_scale)) roi_start_h = int(cpt.round(roi[2] * self.spatial_scale)) roi_end_w = int(cpt.round(roi[3] * self.spatial_scale)) @@ -125,7 +125,7 @@ class TestROIPoolOp(OpTest): roi = [bno, x1, y1, x2, y2] rois.append(roi) self.rois_num = len(rois) - self.rois = np.array(rois).astype("int64") + self.rois = np.array(rois).astype("float32") def setUp(self): self.op_type = "roi_pool" diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py index 08c462d9036cacab81dab7c9ea16664c9159479f..bd548009b3ada9512e4b5f7d7b61b67b0717a39b 
100644 --- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py @@ -18,12 +18,17 @@ import unittest import numpy as np import paddle.fluid.core as core from op_test import OpTest +from test_anchor_generator_op import anchor_generator_in_python +from test_generate_proposal_labels import _generate_groundtruth +from test_generate_proposal_labels import _bbox_overlaps, _box_to_delta -def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap, - rpn_negative_overlap, fg_fraction): - iou = np.transpose(iou) +def rpn_target_assign(gt_anchor_iou, rpn_batch_size_per_im, + rpn_positive_overlap, rpn_negative_overlap, fg_fraction): + iou = np.transpose(gt_anchor_iou) anchor_to_gt_max = iou.max(axis=1) + anchor_to_gt_argmax = iou.argmax(axis=1) + gt_to_anchor_argmax = iou.argmax(axis=0) gt_to_anchor_max = iou[gt_to_anchor_argmax, np.arange(iou.shape[1])] anchors_with_max_overlap = np.where(iou == gt_to_anchor_max)[0] @@ -42,59 +47,113 @@ def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap, num_bg = rpn_batch_size_per_im - np.sum(tgt_lbl == 1) bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] + tgt_lbl[bg_inds] = 0 if len(bg_inds) > num_bg: enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] tgt_lbl[enable_inds] = 0 bg_inds = np.where(tgt_lbl == 0)[0] + tgt_lbl[bg_inds] = 0 loc_index = fg_inds score_index = np.hstack((fg_inds, bg_inds)) tgt_lbl = np.expand_dims(tgt_lbl, axis=1) - return loc_index, score_index, tgt_lbl + + gt_inds = anchor_to_gt_argmax[fg_inds] + + return loc_index, score_index, tgt_lbl, gt_inds + + +def get_anchor(n, c, h, w): + input_feat = np.random.random((n, c, h, w)).astype('float32') + anchors, _ = anchor_generator_in_python( + input_feat=input_feat, + anchor_sizes=[32., 64.], + aspect_ratios=[0.5, 1.0], + variances=[1.0, 1.0, 1.0, 1.0], + stride=[16.0, 16.0], + offset=0.5) + return anchors + + +def rpn_blob(anchor, gt_boxes, iou, lod, rpn_batch_size_per_im, + rpn_positive_overlap, rpn_negative_overlap, fg_fraction): + + loc_indexes = [] + score_indexes = [] + tmp_tgt_labels = [] + tgt_bboxes = [] + anchor_num = anchor.shape[0] + + batch_size = len(lod) - 1 + for i in range(batch_size): + b, e = lod[i], lod[i + 1] + iou_slice = iou[b:e, :] + bboxes_slice = gt_boxes[b:e, :] + + loc_idx, score_idx, tgt_lbl, gt_inds = rpn_target_assign( + iou_slice, rpn_batch_size_per_im, rpn_positive_overlap, + rpn_negative_overlap, fg_fraction) + + fg_bboxes = bboxes_slice[gt_inds] + fg_anchors = anchor[loc_idx] + box_deltas = _box_to_delta(fg_anchors, fg_bboxes, [1., 1., 1., 1.]) + + if i == 0: + loc_indexes = loc_idx + score_indexes = score_idx + tmp_tgt_labels = tgt_lbl + tgt_bboxes = box_deltas + else: + loc_indexes = np.concatenate( + [loc_indexes, loc_idx + i * anchor_num]) + score_indexes = np.concatenate( + [score_indexes, score_idx + i * anchor_num]) + tmp_tgt_labels = np.concatenate([tmp_tgt_labels, tgt_lbl]) + tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) + + tgt_labels = tmp_tgt_labels[score_indexes] + return loc_indexes, score_indexes, tgt_bboxes, tgt_labels class TestRpnTargetAssignOp(OpTest): def setUp(self): - iou = np.random.random((10, 8)).astype("float32") - self.op_type = "rpn_target_assign" - self.inputs = {'DistMat': iou} - self.attrs = { - 'rpn_batch_size_per_im': 256, - 'rpn_positive_overlap': 0.95, - 'rpn_negative_overlap': 0.3, - 'fg_fraction': 0.25, - 'fix_seed': True - } - loc_index, score_index, tgt_lbl = 
rpn_target_assign(iou, 256, 0.95, 0.3, - 0.25) - self.outputs = { - 'LocationIndex': loc_index, - 'ScoreIndex': score_index, - 'TargetLabel': tgt_lbl, - } + n, c, h, w = 2, 4, 14, 14 + anchor = get_anchor(n, c, h, w) + gt_num = 10 + anchor = anchor.reshape(-1, 4) + anchor_num = anchor.shape[0] - def test_check_output(self): - self.check_output() + im_shapes = [[64, 64], [64, 64]] + gt_box, lod = _generate_groundtruth(im_shapes, 3, 4) + bbox = np.vstack([v['boxes'] for v in gt_box]) + iou = _bbox_overlaps(bbox, anchor) + + anchor = anchor.astype('float32') + bbox = bbox.astype('float32') + iou = iou.astype('float32') + + loc_index, score_index, tgt_bbox, tgt_lbl = rpn_blob( + anchor, bbox, iou, [0, 4, 8], 25600, 0.95, 0.03, 0.25) -class TestRpnTargetAssignOp2(OpTest): - def setUp(self): - iou = np.random.random((10, 20)).astype("float32") self.op_type = "rpn_target_assign" - self.inputs = {'DistMat': iou} + self.inputs = { + 'Anchor': anchor, + 'GtBox': (bbox, [[4, 4]]), + 'DistMat': (iou, [[4, 4]]), + } self.attrs = { - 'rpn_batch_size_per_im': 128, - 'rpn_positive_overlap': 0.5, - 'rpn_negative_overlap': 0.5, - 'fg_fraction': 0.5, + 'rpn_batch_size_per_im': 25600, + 'rpn_positive_overlap': 0.95, + 'rpn_negative_overlap': 0.03, + 'fg_fraction': 0.25, 'fix_seed': True } - loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 128, 0.5, 0.5, - 0.5) self.outputs = { - 'LocationIndex': loc_index, - 'ScoreIndex': score_index, - 'TargetLabel': tgt_lbl, + 'LocationIndex': loc_index.astype('int32'), + 'ScoreIndex': score_index.astype('int32'), + 'TargetBBox': tgt_bbox.astype('float32'), + 'TargetLabel': tgt_lbl.astype('int64'), } def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9814ec0a15e1803b356f300d378c31e57ba36c09 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -0,0 +1,105 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
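+ +# The reference sequence_enumerate() below mirrors the op's semantics: for +# every position in each LoD sequence it collects the win_size consecutive +# indices starting at that position, substituting pad_value once the window +# runs past the end of the sequence; each test case checks the op against it.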
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_enumerate(input_seq, in_lod, win_size, pad_value): + lod0 = [0] + for i in range(0, len(in_lod[0])): + lod0.append(lod0[i] + in_lod[0][i]) + out_seq = [] + for i in range(0, len(lod0) - 1): + for idx in range(lod0[i], lod0[i + 1]): + single_seq = [] + for word_idx in range(win_size): + word_pos = idx + word_idx + dat = input_seq[word_pos] if word_pos < lod0[i+1] \ + else pad_value + single_seq.append(dat) + out_seq.append(single_seq) + return out_seq + + +class TestSequenceEnumerateOp(OpTest): + def setUp(self): + self.op_type = "sequence_enumerate" + self.init_test_case() + self.inputs = {'X': (self.in_seq, self.lod)} + self.attrs = {'win_size': self.win_size, 'pad_value': self.pad_value} + self.outputs = {'Out': (self.out_seq, self.lod)} + + def test_check_output(self): + self.check_output() + + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 2 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + +class TestSequenceEnumerateOpInt64(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int64") + self.lod = [[9, 4, 11, 6]] + self.win_size = 2 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int64") + + +class TestSequenceEnumerateOpLargeWinSize(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 5 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + +class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 30 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + +class TestSequenceEnumerateOpLargePadValue(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 5 + self.pad_value = 5 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index bddeb6617c1743de946b3c5b4b0a465d85f35ce3..8a330e0dee7eda02d0858446778363f2235a3d73 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1096,7 +1096,8 @@ class DistributeTranspiler(object): self.table_name] zero_dim = int( - math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints))) + math.ceil(origin_param_var.shape[0] / float( + len(self.pserver_endpoints)))) table_shape = list(origin_param_var.shape) table_shape[0] = zero_dim
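One subtlety in the final transpiler hunk deserves a note: under Python 2, `/` between two ints floors before `math.ceil` ever sees the value, so the per-pserver shard size could come out one row short; the added `float()` restores true division. A minimal sketch of the corrected arithmetic (the `shard_dim` helper and the example sizes are hypothetical, for illustration only):

```python
import math


def shard_dim(dim0, num_pservers):
    # float() forces true division; without it, Python 2 floors first and
    # math.ceil is applied to an already-truncated quotient.
    return int(math.ceil(dim0 / float(num_pservers)))


# A 10-row table split across 3 pservers needs shards of 4 rows;
# under Python 2, ceil(10 / 3) == ceil(3) would wrongly give 3.
assert shard_dim(10, 3) == 4
```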