提交 6e2e0ec8 编写于 作者: D Dang Qingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into quantize_transpiler_update

......@@ -76,33 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85
## Installation
It is recommended to check out the
[Docker installation guide](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/docker_install_en.html)
before looking into the
[build from source guide](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/build_from_source_en.html).
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/install/install_doc.html) on our website.
## Documentation
We provide [English](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) and
[Chinese](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) documentation.
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.14.0/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/index.html) documentation.
- [Deep Learning 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html)
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/cluster/index_en.html)
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters.
- [Distributed Training on Kubernetes](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/cluster/multi_cluster/k8s_en.html)
You can also run distributed training jobs on Kubernetes clusters.
- [Python API](http://www.paddlepaddle.org/docs/develop/api/en/overview.html)
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.14.0/fluid.html)
Our new API enables much shorter programs.
- [How to Contribute](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/dev/contribute_to_paddle_en.html)
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/advanced_usage/development/contribute_to_paddle.html)
We appreciate your contributions!
......
......@@ -172,6 +172,7 @@ paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'],
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None))
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
......@@ -56,76 +56,5 @@ struct RWLock {
};
#endif
class RWLockGuard {
public:
enum Status { kUnLock, kWRLock, kRDLock };
RWLockGuard(RWLock* rw_lock, Status init_status)
: lock_(rw_lock), status_(Status::kUnLock) {
switch (init_status) {
case Status::kRDLock: {
RDLock();
break;
}
case Status::kWRLock: {
WRLock();
break;
}
case Status::kUnLock: {
break;
}
}
}
void WRLock() {
switch (status_) {
case Status::kUnLock: {
lock_->WRLock();
status_ = Status::kWRLock;
break;
}
case Status::kWRLock: {
break;
}
case Status::kRDLock: {
PADDLE_THROW(
"Please unlock read lock first before invoking write lock.");
break;
}
}
}
void RDLock() {
switch (status_) {
case Status::kUnLock: {
lock_->RDLock();
status_ = Status::kRDLock;
break;
}
case Status::kRDLock: {
break;
}
case Status::kWRLock: {
PADDLE_THROW(
"Please unlock write lock first before invoking read lock.");
break;
}
}
}
void UnLock() {
if (status_ != Status::kUnLock) {
lock_->UNLock();
status_ = Status::kUnLock;
}
}
~RWLockGuard() { UnLock(); }
private:
RWLock* lock_;
Status status_;
};
} // namespace framework
} // namespace paddle
......@@ -118,6 +118,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
output_channels / groups * output_height * output_width * output_depth;
int group_offset_filter = filter->numel() / groups;
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (user_workspace_size > 0) {
......@@ -158,18 +159,20 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
// Allocate on GPU memory
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_filter_desc, filter_data + i * group_offset_filter,
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_filter_desc, filter_data + i * group_offset_filter,
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......@@ -311,7 +314,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_filter_desc, filter_algo, &tmp_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
......@@ -319,15 +326,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc,
data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + i * group_offset_in));
}
}
// ------------------- cudnn conv backward filter ---------------------
......@@ -335,17 +339,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc,
input_data + i * group_offset_in, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc,
filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
cudnn_conv_desc, filter_algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + i * group_offset_filter));
}
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......
......@@ -76,6 +76,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
conv_desc.descriptor<T>(paddings, strides, dilations);
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes;
if (user_workspace_size > 0) {
......@@ -99,21 +100,25 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
// Allocate on GPU memory
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv transpose forward ---------------------
int input_offset = input->numel() / input->dims()[0] / groups;
int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......@@ -201,6 +206,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
std::max(workspace_size_in_bytes, bwd_filter_ws_size);
}
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
int input_offset = input->numel() / input->dims()[0] / groups;
......@@ -212,15 +222,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
}
}
......@@ -230,17 +237,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + filter_offset * g));
};
dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + filter_offset * g));
}
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
/*
* transform that computes target bounding-box regression deltas
* given proposal boxes and ground-truth boxes.
*/
template <typename T>
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
const framework::Tensor& gt_boxes, const T* weights,
const bool normalized, framework::Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
auto trg = framework::EigenTensor<T, 2>::From(*box_delta);
T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y;
for (int64_t i = 0; i < box_num; ++i) {
ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + (normalized == false);
ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + (normalized == false);
ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w;
ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h;
gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + (normalized == false);
gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + (normalized == false);
gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w;
gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h;
trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w;
trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h;
trg(i, 2) = std::log(gt_w / ex_w);
trg(i, 3) = std::log(gt_h / ex_h);
if (weights) {
trg(i, 0) = trg(i, 0) / weights[0];
trg(i, 1) = trg(i, 1) / weights[1];
trg(i, 2) = trg(i, 2) / weights[2];
trg(i, 3) = trg(i, 3) / weights[3];
}
}
}
template <typename T>
void Gather(const T* in, const int in_stride, const int* index, const int num,
T* out) {
const int stride_bytes = in_stride * sizeof(T);
for (int i = 0; i < num; ++i) {
int id = index[i];
memcpy(out + i * in_stride, in + id * in_stride, stride_bytes);
}
}
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -133,31 +134,6 @@ void BboxOverlaps(const Tensor& r_boxes, const Tensor& c_boxes,
}
}
template <typename T>
void BoxToDelta(int box_num, const Tensor& ex_boxes, const Tensor& gt_boxes,
const std::vector<float>& weights, Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
auto box_delta_et = framework::EigenTensor<T, 2>::From(*box_delta);
T ex_w, ex_h, ex_ctr_x, ex_ctr_y, gt_w, gt_h, gt_ctr_x, gt_ctr_y;
for (int64_t i = 0; i < box_num; ++i) {
ex_w = ex_boxes_et(i, 2) - ex_boxes_et(i, 0) + 1;
ex_h = ex_boxes_et(i, 3) - ex_boxes_et(i, 1) + 1;
ex_ctr_x = ex_boxes_et(i, 0) + 0.5 * ex_w;
ex_ctr_y = ex_boxes_et(i, 1) + 0.5 * ex_h;
gt_w = gt_boxes_et(i, 2) - gt_boxes_et(i, 0) + 1;
gt_h = gt_boxes_et(i, 3) - gt_boxes_et(i, 1) + 1;
gt_ctr_x = gt_boxes_et(i, 0) + 0.5 * gt_w;
gt_ctr_y = gt_boxes_et(i, 1) + 0.5 * gt_h;
box_delta_et(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0];
box_delta_et(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1];
box_delta_et(i, 2) = log(gt_w / ex_w) / ex_w / weights[2];
box_delta_et(i, 3) = log(gt_h / ex_h) / ex_h / weights[3];
}
}
template <typename T>
std::vector<std::vector<int>> SampleFgBgGt(
const platform::CPUDeviceContext& context, Tensor* iou,
......@@ -243,12 +219,11 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
Tensor* sampled_labels, Tensor* sampled_gts) {
int fg_num = fg_inds.size();
int bg_num = bg_inds.size();
int gt_num = fg_num + bg_num;
Tensor fg_inds_t, bg_inds_t, gt_box_inds_t, gt_label_inds_t;
int* fg_inds_data = fg_inds_t.mutable_data<int>({fg_num}, context.GetPlace());
int* bg_inds_data = bg_inds_t.mutable_data<int>({bg_num}, context.GetPlace());
int* gt_box_inds_data =
gt_box_inds_t.mutable_data<int>({gt_num}, context.GetPlace());
gt_box_inds_t.mutable_data<int>({fg_num}, context.GetPlace());
int* gt_label_inds_data =
gt_label_inds_t.mutable_data<int>({fg_num}, context.GetPlace());
std::copy(fg_inds.begin(), fg_inds.end(), fg_inds_data);
......@@ -303,18 +278,20 @@ std::vector<Tensor> SampleRoisForOneImage(
// Gather boxes and labels
Tensor sampled_boxes, sampled_labels, sampled_gts;
int boxes_num = fg_inds.size() + bg_inds.size();
int fg_num = fg_inds.size();
int bg_num = bg_inds.size();
int boxes_num = fg_num + bg_num;
framework::DDim bbox_dim({boxes_num, kBoxDim});
sampled_boxes.mutable_data<T>(bbox_dim, context.GetPlace());
sampled_labels.mutable_data<int>({boxes_num}, context.GetPlace());
sampled_gts.mutable_data<T>(bbox_dim, context.GetPlace());
sampled_gts.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
GatherBoxesLabels<T>(context, boxes, *gt_boxes, *gt_classes, fg_inds, bg_inds,
gt_inds, &sampled_boxes, &sampled_labels, &sampled_gts);
// Compute targets
Tensor bbox_targets_single;
bbox_targets_single.mutable_data<T>(bbox_dim, context.GetPlace());
BoxToDelta<T>(boxes_num, sampled_boxes, sampled_gts, bbox_reg_weights,
BoxToDelta<T>(fg_num, sampled_boxes, sampled_gts, nullptr, false,
&bbox_targets_single);
// Scale rois
......@@ -427,7 +404,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
auto rpn_rois_lod = rpn_rois->lod().back();
auto gt_classes_lod = gt_classes->lod().back();
auto gt_boxes_lod = gt_boxes->lod().back();
for (size_t i = 0; i < n; ++i) {
for (int i = 0; i < n; ++i) {
Tensor rpn_rois_slice =
rpn_rois->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]);
Tensor gt_classes_slice =
......
......@@ -311,8 +311,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel() / 4, 1},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
......@@ -421,7 +420,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
CPUGather<T>(ctx, proposals, keep, &bbox_sel);
CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_sel);
return std::make_pair(bbox_sel, scores_filter);
}
Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
......@@ -46,156 +47,219 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
auto in_dims = ctx->GetInputDim("DistMat");
PADDLE_ENFORCE_EQ(in_dims.size(), 2,
"The rank of Input(DistMat) must be 2.");
ctx->SetOutputDim("LocationIndex", {-1});
ctx->SetOutputDim("ScoreIndex", {-1});
ctx->SetOutputDim("TargetLabel", {-1, 1});
ctx->SetOutputDim("TargetBBox", {-1, 4});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("DistMat")->type()),
platform::CPUPlace());
}
};
template <typename T>
class RpnTargetAssignKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* anchor_t = context.Input<Tensor>("Anchor"); // (H*W*A) * 4
auto* gt_bbox_t = context.Input<Tensor>("GtBox");
auto* dist_t = context.Input<LoDTensor>("DistMat");
auto* loc_index_t = context.Output<Tensor>("LocationIndex");
auto* score_index_t = context.Output<Tensor>("ScoreIndex");
auto* tgt_bbox_t = context.Output<Tensor>("TargetBBox");
auto* tgt_lbl_t = context.Output<Tensor>("TargetLabel");
auto lod = dist_t->lod().back();
int64_t batch_num = static_cast<int64_t>(lod.size() - 1);
int64_t anchor_num = dist_t->dims()[1];
PADDLE_ENFORCE_EQ(anchor_num, anchor_t->dims()[0]);
int rpn_batch_size = context.Attr<int>("rpn_batch_size_per_im");
float pos_threshold = context.Attr<float>("rpn_positive_overlap");
float neg_threshold = context.Attr<float>("rpn_negative_overlap");
float fg_fraction = context.Attr<float>("fg_fraction");
int fg_num_per_batch = static_cast<int>(rpn_batch_size * fg_fraction);
int64_t max_num = batch_num * anchor_num;
auto place = context.GetPlace();
tgt_bbox_t->mutable_data<T>({max_num, 4}, place);
auto* loc_index = loc_index_t->mutable_data<int>({max_num}, place);
auto* score_index = score_index_t->mutable_data<int>({max_num}, place);
Tensor tmp_tgt_lbl;
auto* tmp_lbl_data = tmp_tgt_lbl.mutable_data<int64_t>({max_num}, place);
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, int64_t> iset;
iset(dev_ctx, &tmp_tgt_lbl, static_cast<int64_t>(-1));
std::random_device rnd;
std::minstd_rand engine;
int seed =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
engine.seed(seed);
int fg_num = 0;
int bg_num = 0;
for (int i = 0; i < batch_num; ++i) {
Tensor dist = dist_t->Slice(lod[i], lod[i + 1]);
Tensor gt_bbox = gt_bbox_t->Slice(lod[i], lod[i + 1]);
auto fg_bg_gt = SampleFgBgGt(dev_ctx, dist, pos_threshold, neg_threshold,
rpn_batch_size, fg_num_per_batch, engine,
tmp_lbl_data + i * anchor_num);
int cur_fg_num = fg_bg_gt[0].size();
int cur_bg_num = fg_bg_gt[1].size();
std::transform(fg_bg_gt[0].begin(), fg_bg_gt[0].end(), loc_index,
[i, anchor_num](int d) { return d + i * anchor_num; });
memcpy(score_index, loc_index, cur_fg_num * sizeof(int));
std::transform(fg_bg_gt[1].begin(), fg_bg_gt[1].end(),
score_index + cur_fg_num,
[i, anchor_num](int d) { return d + i * anchor_num; });
// get target bbox deltas
if (cur_fg_num) {
Tensor fg_gt;
T* gt_data = fg_gt.mutable_data<T>({cur_fg_num, 4}, place);
Tensor tgt_bbox = tgt_bbox_t->Slice(fg_num, fg_num + cur_fg_num);
T* tgt_data = tgt_bbox.data<T>();
Gather<T>(anchor_t->data<T>(), 4,
reinterpret_cast<int*>(&fg_bg_gt[0][0]), cur_fg_num,
tgt_data);
Gather<T>(gt_bbox.data<T>(), 4, reinterpret_cast<int*>(&fg_bg_gt[2][0]),
cur_fg_num, gt_data);
BoxToDelta<T>(cur_fg_num, tgt_bbox, fg_gt, nullptr, false, &tgt_bbox);
}
loc_index += cur_fg_num;
score_index += cur_fg_num + cur_bg_num;
fg_num += cur_fg_num;
bg_num += cur_bg_num;
}
int lbl_num = fg_num + bg_num;
PADDLE_ENFORCE_LE(fg_num, max_num);
PADDLE_ENFORCE_LE(lbl_num, max_num);
tgt_bbox_t->Resize({fg_num, 4});
loc_index_t->Resize({fg_num});
score_index_t->Resize({lbl_num});
auto* lbl_data = tgt_lbl_t->mutable_data<int64_t>({lbl_num, 1}, place);
Gather<int64_t>(tmp_lbl_data, 1, score_index_t->data<int>(), lbl_num,
lbl_data);
}
private:
void ScoreAssign(const T* dist_data, const Tensor& anchor_to_gt_max,
const int row, const int col, const float pos_threshold,
const float neg_threshold, int64_t* target_label_data,
const float neg_threshold, int64_t* target_label,
std::vector<int>* fg_inds, std::vector<int>* bg_inds) const {
int fg_offset = fg_inds->size();
int bg_offset = bg_inds->size();
float epsilon = 0.0001;
for (int64_t i = 0; i < row; ++i) {
const T* v = dist_data + i * col;
T max_dist = *std::max_element(v, v + col);
T max = *std::max_element(v, v + col);
for (int64_t j = 0; j < col; ++j) {
T val = dist_data[i * col + j];
if (val == max_dist) target_label_data[j] = 1;
if (std::abs(max - v[j]) < epsilon) {
target_label[j] = 1;
}
}
}
// Pick the fg/bg and count the number
// Pick the fg/bg
const T* anchor_to_gt_max_data = anchor_to_gt_max.data<T>();
for (int64_t j = 0; j < col; ++j) {
if (anchor_to_gt_max.data<T>()[j] > pos_threshold) {
target_label_data[j] = 1;
} else if (anchor_to_gt_max.data<T>()[j] < neg_threshold) {
target_label_data[j] = 0;
if (anchor_to_gt_max_data[j] >= pos_threshold) {
target_label[j] = 1;
} else if (anchor_to_gt_max_data[j] < neg_threshold) {
target_label[j] = 0;
}
if (target_label_data[j] == 1) {
fg_inds->push_back(fg_offset + j);
} else if (target_label_data[j] == 0) {
bg_inds->push_back(bg_offset + j);
if (target_label[j] == 1) {
fg_inds->push_back(j);
} else if (target_label[j] == 0) {
bg_inds->push_back(j);
}
}
}
void ReservoirSampling(const int num, const int offset,
std::minstd_rand engine,
void ReservoirSampling(const int num, std::minstd_rand engine,
std::vector<int>* inds) const {
std::uniform_real_distribution<float> uniform(0, 1);
const int64_t size = static_cast<int64_t>(inds->size() - offset);
if (size > num) {
for (int64_t i = num; i < size; ++i) {
size_t len = inds->size();
if (len > static_cast<size_t>(num)) {
for (size_t i = num; i < len; ++i) {
int rng_ind = std::floor(uniform(engine) * i);
if (rng_ind < num)
std::iter_swap(inds->begin() + rng_ind + offset,
inds->begin() + i + offset);
std::iter_swap(inds->begin() + rng_ind, inds->begin() + i);
}
inds->resize(num);
}
}
void RpnTargetAssign(const framework::ExecutionContext& ctx,
const Tensor& dist, const float pos_threshold,
const float neg_threshold, const int rpn_batch_size,
const int fg_num, std::minstd_rand engine,
std::vector<int>* fg_inds, std::vector<int>* bg_inds,
int64_t* target_label_data) const {
// std::vector<std::vector<int>> RpnTargetAssign(
std::vector<std::vector<int>> SampleFgBgGt(
const platform::CPUDeviceContext& ctx, const Tensor& dist,
const float pos_threshold, const float neg_threshold,
const int rpn_batch_size, const int fg_num, std::minstd_rand engine,
int64_t* target_label) const {
auto* dist_data = dist.data<T>();
int64_t row = dist.dims()[0];
int64_t col = dist.dims()[1];
int fg_offset = fg_inds->size();
int bg_offset = bg_inds->size();
int row = dist.dims()[0];
int col = dist.dims()[1];
std::vector<int> fg_inds;
std::vector<int> bg_inds;
std::vector<int> gt_inds;
// Calculate the max IoU between anchors and gt boxes
Tensor anchor_to_gt_max;
anchor_to_gt_max.mutable_data<T>(
framework::make_ddim({static_cast<int64_t>(col), 1}),
platform::CPUPlace());
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto x = EigenMatrix<T>::From(dist);
auto x_col_max = EigenMatrix<T>::From(anchor_to_gt_max);
x_col_max.device(place) =
x.maximum(Eigen::DSizes<int, 1>(0))
.reshape(Eigen::DSizes<int, 2>(static_cast<int64_t>(col), 1));
// Map from anchor to gt box that has highest overlap
auto place = ctx.GetPlace();
Tensor anchor_to_gt_max, anchor_to_gt_argmax;
anchor_to_gt_max.mutable_data<T>({col}, place);
int* argmax = anchor_to_gt_argmax.mutable_data<int>({col}, place);
auto x = framework::EigenMatrix<T>::From(dist);
auto x_col_max = framework::EigenVector<T>::Flatten(anchor_to_gt_max);
auto x_col_argmax =
framework::EigenVector<int>::Flatten(anchor_to_gt_argmax);
x_col_max = x.maximum(Eigen::DSizes<int, 1>(0));
x_col_argmax = x.argmax(0).template cast<int>();
// Follow the Faster RCNN's implementation
ScoreAssign(dist_data, anchor_to_gt_max, row, col, pos_threshold,
neg_threshold, target_label_data, fg_inds, bg_inds);
neg_threshold, target_label, &fg_inds, &bg_inds);
// Reservoir Sampling
ReservoirSampling(fg_num, fg_offset, engine, fg_inds);
int bg_num = rpn_batch_size - (fg_inds->size() - fg_offset);
ReservoirSampling(bg_num, bg_offset, engine, bg_inds);
}
ReservoirSampling(fg_num, engine, &fg_inds);
int fg_num2 = static_cast<int>(fg_inds.size());
int bg_num = rpn_batch_size - fg_num2;
ReservoirSampling(bg_num, engine, &bg_inds);
void Compute(const framework::ExecutionContext& context) const override {
auto* dist = context.Input<LoDTensor>("DistMat");
auto* loc_index = context.Output<Tensor>("LocationIndex");
auto* score_index = context.Output<Tensor>("ScoreIndex");
auto* tgt_lbl = context.Output<Tensor>("TargetLabel");
auto col = dist->dims()[1];
int64_t n = dist->lod().size() == 0UL
? 1
: static_cast<int64_t>(dist->lod().back().size() - 1);
if (dist->lod().size()) {
PADDLE_ENFORCE_EQ(dist->lod().size(), 1UL,
"Only support 1 level of LoD.");
gt_inds.reserve(fg_num2);
for (int i = 0; i < fg_num2; ++i) {
gt_inds.emplace_back(argmax[fg_inds[i]]);
}
int rpn_batch_size = context.Attr<int>("rpn_batch_size_per_im");
float pos_threshold = context.Attr<float>("rpn_positive_overlap");
float neg_threshold = context.Attr<float>("rpn_negative_overlap");
float fg_fraction = context.Attr<float>("fg_fraction");
int fg_num = static_cast<int>(rpn_batch_size * fg_fraction);
int64_t* target_label_data =
tgt_lbl->mutable_data<int64_t>({n * col, 1}, context.GetPlace());
std::vector<std::vector<int>> fg_bg_gt;
fg_bg_gt.emplace_back(fg_inds);
fg_bg_gt.emplace_back(bg_inds);
fg_bg_gt.emplace_back(gt_inds);
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, int64_t> iset;
iset(dev_ctx, tgt_lbl, static_cast<int>(-1));
std::vector<int> fg_inds;
std::vector<int> bg_inds;
std::random_device rnd;
std::minstd_rand engine;
int seed =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
engine.seed(seed);
if (n == 1) {
RpnTargetAssign(context, *dist, pos_threshold, neg_threshold,
rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds,
target_label_data);
} else {
auto lod = dist->lod().back();
for (size_t i = 0; i < lod.size() - 1; ++i) {
Tensor one_ins = dist->Slice(lod[i], lod[i + 1]);
RpnTargetAssign(context, one_ins, pos_threshold, neg_threshold,
rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds,
target_label_data + i * col);
}
}
int* loc_index_data = loc_index->mutable_data<int>(
{static_cast<int>(fg_inds.size())}, context.GetPlace());
int* score_index_data = score_index->mutable_data<int>(
{static_cast<int>(fg_inds.size() + bg_inds.size())},
context.GetPlace());
memcpy(loc_index_data, reinterpret_cast<int*>(&fg_inds[0]),
fg_inds.size() * sizeof(int));
memcpy(score_index_data, reinterpret_cast<int*>(&fg_inds[0]),
fg_inds.size() * sizeof(int));
memcpy(score_index_data + fg_inds.size(),
reinterpret_cast<int*>(&bg_inds[0]), bg_inds.size() * sizeof(int));
return fg_bg_gt;
}
};
class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Anchor",
"(Tensor) input anchor is a 2-D Tensor with shape [H*W*A, 4].");
AddInput("GtBox", "(LoDTensor) input groud-truth bbox with shape [K, 4].");
AddInput(
"DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
......@@ -241,12 +305,15 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"ScoreIndex",
"(Tensor), The indexes of foreground and background anchors in all "
"RPN anchors(The rest anchors are ignored). The shape of the "
"ScoreIndex is [F + B], F and B depend on the value of input "
"tensor and attributes.");
AddOutput("TargetLabel",
"(Tensor<int64_t>), The target labels of each anchor with shape "
"[K * M, 1], "
"K and M is the same as they are in DistMat.");
"ScoreIndex is [F + B], F and B are sampled foreground and backgroud "
" number.");
AddOutput("TargetBBox",
"(Tensor<int64_t>), The target bbox deltas with shape "
"[F, 4], F is the sampled foreground number.");
AddOutput(
"TargetLabel",
"(Tensor<int64_t>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number.");
AddComment(R"DOC(
This operator can be, for given the IoU between the ground truth bboxes and the
anchors, to assign classification and regression targets to each prediction.
......
......@@ -31,7 +31,7 @@ static inline int NumBlocks(const int N) {
template <typename T>
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const int64_t* input_rois,
const int nthreads, const T* input_data, const T* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
......@@ -43,7 +43,7 @@ __global__ void GPUROIPoolForward(
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
const T* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
int roi_start_w = round(offset_input_rois[0] * spatial_scale);
int roi_start_h = round(offset_input_rois[1] * spatial_scale);
......@@ -93,7 +93,7 @@ __global__ void GPUROIPoolForward(
template <typename T>
__global__ void GPUROIPoolBackward(
const int nthreads, const int64_t* input_rois, const T* output_grad,
const int nthreads, const T* input_rois, const T* output_grad,
const int64_t* argmax_data, const int num_rois, const float spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, int* roi_batch_id_data,
......@@ -174,8 +174,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
GPUROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
channels, height, width, pooled_height, pooled_width,
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
height, width, pooled_height, pooled_width,
roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
......@@ -228,7 +228,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
if (output_grad_size > 0) {
GPUROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
output_grad_size, rois->data<T>(), out_grad->data<T>(),
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
width, pooled_height, pooled_width,
roi_batch_id_list_gpu.data<int>(),
......
......@@ -72,7 +72,7 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
T* output_data = out->mutable_data<T>(ctx.GetPlace());
int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
const int64_t* rois_data = rois->data<int64_t>();
const T* rois_data = rois->data<T>();
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = roi_batch_id_data[n];
int roi_start_w = round(rois_data[0] * spatial_scale);
......@@ -171,7 +171,7 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
}
}
const int64_t* rois_data = rois->data<int64_t>();
const T* rois_data = rois->data<T>();
const T* out_grad_data = out_grad->data<T>();
const int64_t* argmax_data = argmax->data<int64_t>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_enumerate_op.h"
namespace paddle {
namespace operators {
class SequenceEnumerateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of SequecceEnumerate operator should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(X) of SequenceEnumerate operator should not be null.");
const auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
x_dims.size(), 2UL,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], 1UL,
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
const auto win_size = ctx->Attrs().Get<int>("win_size");
ctx->SetOutputDim("Out", {x_dims[0], win_size});
ctx->ShareLoD("X", "Out");
}
};
class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(2-D LoDTensor with the 2nd dimension equal to 1) "
"Input LoDTensor of SequenceEnumerate operator.");
AddOutput("Out",
"(2-D LoDTensor with the 2nd dimension equal to win_size) "
"Output LoDTensor of SequenceEnumerate operator.");
AddAttr<int>("win_size", "(int) The enumerate sequence window size.")
.AddCustomChecker([](const int& win_size) {
PADDLE_ENFORCE(win_size >= 2,
"The window size should be not less than 2.");
});
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
.SetDefault(0);
AddComment(R"DOC(
Sequence Enumerate Operator.
Generate a new sequence for the input index sequence, which enumerates all the
sub-sequences with length `win_size` of the input.
The enumerated sequence has the same 1st dimension with variable `input`, and
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp,
ops::SequenceEnumerateOpMaker);
REGISTER_OP_CPU_KERNEL(
sequence_enumerate,
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/sequence_enumerate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void CalcOutPut(const T* in_data, const size_t* in_lod,
const size_t lod_len, const int64_t win_size,
const int64_t pad_value, T* out_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_lod[lod_len - 1]) {
int end_idx = 0;
// Get LoD interval of index
for (int i = 1; i < lod_len; ++i) {
if (index < in_lod[i]) {
end_idx = in_lod[i];
break;
}
}
for (size_t i = 0; i < win_size; ++i) {
int word_pos = index + i;
out_data[index * win_size + i] =
word_pos < end_idx ? in_data[word_pos] : pad_value;
}
}
}
template <typename T>
class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int win_size = context.Attr<int>("win_size");
int pad_value = context.Attr<int>("pad_value");
auto in_dims = in->dims();
auto in_lod = in->lod();
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
"The actual input data's size mismatched with LoD information.");
/* Generate enumerate sequence set */
auto stream = context.cuda_device_context().stream();
auto lod0 = in_lod[0];
auto in_len = in->numel();
auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
// Calc output tensor
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
sequence_enumerate,
paddle::operators::SequenceEnumerateOpCUDAKernel<int32_t>,
paddle::operators::SequenceEnumerateOpCUDAKernel<int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class SequenceEnumerateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int win_size = context.Attr<int>("win_size");
int pad_value = context.Attr<int>("pad_value");
auto in_dims = in->dims();
auto in_lod = in->lod();
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
"The actual input data's size mismatched with LoD information.");
// Generate enumerate sequence set
auto lod0 = in_lod[0];
auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
size_t word_pos = idx + word_idx;
out_data[win_size * idx + word_idx] =
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -16,9 +16,6 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
#endif
namespace paddle {
namespace platform {
......@@ -145,59 +142,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_;
};
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_len) {
std::lock_guard<std::mutex> lock(mtx_);
if (required_workspace_len > workspace_len_) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_);
}
~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
}
private:
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) {
return;
}
void* new_workspace = paddle::memory::Alloc(place_, required_workspace_len);
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_);
}
workspace_ = new_workspace;
workspace_len_ = required_workspace_len;
}
cudnnHandle_t cudnn_handle_;
void* workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
SetDeviceId(place_.device);
compute_capability = GetCUDAComputeCapability(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device);
......@@ -209,7 +154,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
} else {
cudnn_handle_ = nullptr;
}
}
......@@ -217,6 +165,9 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
if (cudnn_handle_ != nullptr) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -245,14 +196,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle();
}
void CUDADeviceContext::RunCudnnFuncWithWorkspace(
const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
cudnn_holder_->RunFunc(cudnn_func, workspace_len);
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
......@@ -69,7 +69,6 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
class CudnnHolder;
class CUDADeviceContext : public DeviceContext {
public:
......@@ -97,11 +96,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& cudnn_func,
size_t workspace_len) const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
......@@ -117,8 +111,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
int compute_capability;
......
......@@ -145,26 +145,23 @@ def rpn_target_assign(loc,
"""
helper = LayerHelper('rpn_target_assign', **locals())
# 1. Compute the regression target bboxes
target_bbox = box_coder(
prior_box=anchor_box,
prior_box_var=anchor_var,
target_box=gt_box,
code_type='encode_center_size',
box_normalized=False)
# 2. Compute overlaps between the prior boxes and the gt boxes overlaps
# Compute overlaps between the prior boxes and the gt boxes overlaps
iou = iou_similarity(x=gt_box, y=anchor_box)
# 3. Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype=anchor_box.dtype)
score_index = helper.create_tmp_variable(dtype=anchor_box.dtype)
target_label = helper.create_tmp_variable(dtype=anchor_box.dtype)
# Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype='int32')
score_index = helper.create_tmp_variable(dtype='int32')
target_label = helper.create_tmp_variable(dtype='int64')
target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
helper.append_op(
type="rpn_target_assign",
inputs={'DistMat': iou},
inputs={'Anchor': anchor_box,
'GtBox': gt_box,
'DistMat': iou},
outputs={
'LocationIndex': loc_index,
'ScoreIndex': score_index,
'TargetLabel': target_label
'TargetLabel': target_label,
'TargetBBox': target_bbox,
},
attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im,
......@@ -173,16 +170,16 @@ def rpn_target_assign(loc,
'fg_fraction': fg_fraction
})
# 4. Reshape and gather the target entry
scores = nn.reshape(x=scores, shape=(-1, 2))
loc = nn.reshape(x=loc, shape=(-1, 4))
target_label = nn.reshape(x=target_label, shape=(-1, 1))
target_bbox = nn.reshape(x=target_bbox, shape=(-1, 4))
loc_index.stop_gradient = True
score_index.stop_gradient = True
target_label.stop_gradient = True
target_bbox.stop_gradient = True
scores = nn.reshape(x=scores, shape=(-1, 1))
loc = nn.reshape(x=loc, shape=(-1, 4))
predicted_scores = nn.gather(scores, score_index)
predicted_location = nn.gather(loc, loc_index)
target_label = nn.gather(target_label, score_index)
target_bbox = nn.gather(target_bbox, loc_index)
return predicted_scores, predicted_location, target_label, target_bbox
......
......@@ -111,6 +111,7 @@ __all__ = [
'stack',
'pad2d',
'unstack',
'sequence_enumerate',
]
......@@ -5823,6 +5824,51 @@ def flatten(x, axis=1, name=None):
return out
def sequence_enumerate(input, win_size, pad_value=0, name=None):
"""
Generate a new sequence for the input index sequence, which enumerates all the
sub-sequences with length `win_size` of the input.
The enumerated sequence has the same 1st dimension with variable `input`, and
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
Args:
input (Variable): The input variable which is a index sequence.
win_size (int): The window size for enumerating all sub-sequences.
pad_value (int): The padding value, default 0.
Returns:
Variable: The enumerate sequence variable which is a LoDTensor.
Examples:
.. code-block:: python
x = fluid.layers.data(shape[30, 1], dtype='int32', lod_level=1)
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
"""
helper = LayerHelper('sequence_enumerate', **locals())
out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True)
helper.append_op(
type='sequence_enumerate',
inputs={'X': input},
outputs={'Out': out},
attrs={'win_size': win_size,
'pad_value': pad_value})
def sequence_mask(x, maxlen=None, dtype='int64', name=None):
"""
**SequenceMask Layer**
......@@ -5902,6 +5948,7 @@ def stack(x, axis=0):
helper.append_op(
type='stack', inputs={'X': x}, outputs={'Y': out},
attrs={'axis': axis})
return out
......
......@@ -281,7 +281,7 @@ class TestRpnTargetAssign(unittest.TestCase):
gt_box = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32')
predicted_scores, predicted_location, target_label, target_bbox = layers.rpn_target_assign(
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
loc=loc,
scores=scores,
anchor_box=anchor_box,
......@@ -292,15 +292,13 @@ class TestRpnTargetAssign(unittest.TestCase):
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3)
self.assertIsNotNone(predicted_scores)
self.assertIsNotNone(predicted_location)
self.assertIsNotNone(target_label)
self.assertIsNotNone(target_bbox)
assert predicted_scores.shape[1] == 2
assert predicted_location.shape[1] == 4
assert predicted_location.shape[1] == target_bbox.shape[1]
print(str(program))
self.assertIsNotNone(pred_scores)
self.assertIsNotNone(pred_loc)
self.assertIsNotNone(tgt_lbl)
self.assertIsNotNone(tgt_bbox)
assert pred_scores.shape[1] == 1
assert pred_loc.shape[1] == 4
assert pred_loc.shape[1] == tgt_bbox.shape[1]
class TestGenerateProposals(unittest.TestCase):
......
......@@ -37,7 +37,7 @@ def fusion_gru(
h0,
wh,
np.zeros(
(1, wh.shape[1]), dtype='float64'),
(1, wh.shape[1]), dtype='float32'),
is_reverse,
act_state,
act_gate)
......@@ -62,15 +62,15 @@ class TestFusionGRUOp(OpTest):
T = sum(self.lod[0])
N = len(self.lod[0])
x = np.random.rand(T, self.M).astype('float64')
wx = np.random.rand(self.M, 3 * self.D).astype('float64')
wh = np.random.rand(self.D, 3 * self.D).astype('float64')
x = np.random.rand(T, self.M).astype('float32')
wx = np.random.rand(self.M, 3 * self.D).astype('float32')
wh = np.random.rand(self.D, 3 * self.D).astype('float32')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
1, 3 * self.D).astype('float32') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float32')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
N, self.D).astype('float32') if self.with_h0 else np.zeros(
(N, self.D), dtype='float32')
_, _, _, hidden = fusion_gru(
x, self.lod, h0, wx, wh, bias, self.is_reverse,
......@@ -93,7 +93,9 @@ class TestFusionGRUOp(OpTest):
}
def test_check_output(self):
self.check_output(atol=1e-8)
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output()
class TestFusionGRUOpNoInitial(TestFusionGRUOp):
......
......@@ -114,7 +114,9 @@ class TestFusionLSTMOp(OpTest):
}
def test_check_output(self):
self.check_output()
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output()
class TestFusionLSTMOpInit(TestFusionLSTMOp):
......
......@@ -177,8 +177,8 @@ def _box_to_delta(ex_boxes, gt_boxes, weights):
dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
dw = (np.log(gt_w / ex_w)) / ex_w / weights[2]
dh = (np.log(gt_h / ex_h)) / ex_h / weights[3]
dw = (np.log(gt_w / ex_w)) / weights[2]
dh = (np.log(gt_h / ex_h)) / weights[3]
targets = np.vstack([dx, dy, dw, dh]).transpose()
return targets
......
......@@ -549,6 +549,13 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_sequence_enumerate(self):
program = Program()
with program_guard(program):
x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1)
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -61,7 +61,7 @@ class TestROIPoolOp(OpTest):
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = roi[0]
roi_batch_id = int(roi[0])
roi_start_w = int(cpt.round(roi[1] * self.spatial_scale))
roi_start_h = int(cpt.round(roi[2] * self.spatial_scale))
roi_end_w = int(cpt.round(roi[3] * self.spatial_scale))
......@@ -125,7 +125,7 @@ class TestROIPoolOp(OpTest):
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("int64")
self.rois = np.array(rois).astype("float32")
def setUp(self):
self.op_type = "roi_pool"
......
......@@ -18,12 +18,17 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from test_anchor_generator_op import anchor_generator_in_python
from test_generate_proposal_labels import _generate_groundtruth
from test_generate_proposal_labels import _bbox_overlaps, _box_to_delta
def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap,
rpn_negative_overlap, fg_fraction):
iou = np.transpose(iou)
def rpn_target_assign(gt_anchor_iou, rpn_batch_size_per_im,
rpn_positive_overlap, rpn_negative_overlap, fg_fraction):
iou = np.transpose(gt_anchor_iou)
anchor_to_gt_max = iou.max(axis=1)
anchor_to_gt_argmax = iou.argmax(axis=1)
gt_to_anchor_argmax = iou.argmax(axis=0)
gt_to_anchor_max = iou[gt_to_anchor_argmax, np.arange(iou.shape[1])]
anchors_with_max_overlap = np.where(iou == gt_to_anchor_max)[0]
......@@ -42,59 +47,113 @@ def rpn_target_assign(iou, rpn_batch_size_per_im, rpn_positive_overlap,
num_bg = rpn_batch_size_per_im - np.sum(tgt_lbl == 1)
bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
tgt_lbl[bg_inds] = 0
if len(bg_inds) > num_bg:
enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
tgt_lbl[enable_inds] = 0
bg_inds = np.where(tgt_lbl == 0)[0]
tgt_lbl[bg_inds] = 0
loc_index = fg_inds
score_index = np.hstack((fg_inds, bg_inds))
tgt_lbl = np.expand_dims(tgt_lbl, axis=1)
return loc_index, score_index, tgt_lbl
gt_inds = anchor_to_gt_argmax[fg_inds]
return loc_index, score_index, tgt_lbl, gt_inds
def get_anchor(n, c, h, w):
input_feat = np.random.random((n, c, h, w)).astype('float32')
anchors, _ = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=[32., 64.],
aspect_ratios=[0.5, 1.0],
variances=[1.0, 1.0, 1.0, 1.0],
stride=[16.0, 16.0],
offset=0.5)
return anchors
def rpn_blob(anchor, gt_boxes, iou, lod, rpn_batch_size_per_im,
rpn_positive_overlap, rpn_negative_overlap, fg_fraction):
loc_indexes = []
score_indexes = []
tmp_tgt_labels = []
tgt_bboxes = []
anchor_num = anchor.shape[0]
batch_size = len(lod) - 1
for i in range(batch_size):
b, e = lod[i], lod[i + 1]
iou_slice = iou[b:e, :]
bboxes_slice = gt_boxes[b:e, :]
loc_idx, score_idx, tgt_lbl, gt_inds = rpn_target_assign(
iou_slice, rpn_batch_size_per_im, rpn_positive_overlap,
rpn_negative_overlap, fg_fraction)
fg_bboxes = bboxes_slice[gt_inds]
fg_anchors = anchor[loc_idx]
box_deltas = _box_to_delta(fg_anchors, fg_bboxes, [1., 1., 1., 1.])
if i == 0:
loc_indexes = loc_idx
score_indexes = score_idx
tmp_tgt_labels = tgt_lbl
tgt_bboxes = box_deltas
else:
loc_indexes = np.concatenate(
[loc_indexes, loc_idx + i * anchor_num])
score_indexes = np.concatenate(
[score_indexes, score_idx + i * anchor_num])
tmp_tgt_labels = np.concatenate([tmp_tgt_labels, tgt_lbl])
tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
tgt_labels = tmp_tgt_labels[score_indexes]
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels
class TestRpnTargetAssignOp(OpTest):
def setUp(self):
iou = np.random.random((10, 8)).astype("float32")
self.op_type = "rpn_target_assign"
self.inputs = {'DistMat': iou}
self.attrs = {
'rpn_batch_size_per_im': 256,
'rpn_positive_overlap': 0.95,
'rpn_negative_overlap': 0.3,
'fg_fraction': 0.25,
'fix_seed': True
}
loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 256, 0.95, 0.3,
0.25)
self.outputs = {
'LocationIndex': loc_index,
'ScoreIndex': score_index,
'TargetLabel': tgt_lbl,
}
n, c, h, w = 2, 4, 14, 14
anchor = get_anchor(n, c, h, w)
gt_num = 10
anchor = anchor.reshape(-1, 4)
anchor_num = anchor.shape[0]
def test_check_output(self):
self.check_output()
im_shapes = [[64, 64], [64, 64]]
gt_box, lod = _generate_groundtruth(im_shapes, 3, 4)
bbox = np.vstack([v['boxes'] for v in gt_box])
iou = _bbox_overlaps(bbox, anchor)
anchor = anchor.astype('float32')
bbox = bbox.astype('float32')
iou = iou.astype('float32')
loc_index, score_index, tgt_bbox, tgt_lbl = rpn_blob(
anchor, bbox, iou, [0, 4, 8], 25600, 0.95, 0.03, 0.25)
class TestRpnTargetAssignOp2(OpTest):
def setUp(self):
iou = np.random.random((10, 20)).astype("float32")
self.op_type = "rpn_target_assign"
self.inputs = {'DistMat': iou}
self.inputs = {
'Anchor': anchor,
'GtBox': (bbox, [[4, 4]]),
'DistMat': (iou, [[4, 4]]),
}
self.attrs = {
'rpn_batch_size_per_im': 128,
'rpn_positive_overlap': 0.5,
'rpn_negative_overlap': 0.5,
'fg_fraction': 0.5,
'rpn_batch_size_per_im': 25600,
'rpn_positive_overlap': 0.95,
'rpn_negative_overlap': 0.03,
'fg_fraction': 0.25,
'fix_seed': True
}
loc_index, score_index, tgt_lbl = rpn_target_assign(iou, 128, 0.5, 0.5,
0.5)
self.outputs = {
'LocationIndex': loc_index,
'ScoreIndex': score_index,
'TargetLabel': tgt_lbl,
'LocationIndex': loc_index.astype('int32'),
'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': tgt_lbl.astype('int64'),
}
def test_check_output(self):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
def sequence_enumerate(input_seq, in_lod, win_size, pad_value):
lod0 = [0]
for i in range(0, len(in_lod[0])):
lod0.append(lod0[i] + in_lod[0][i])
out_seq = []
for i in range(0, len(lod0) - 1):
for idx in range(lod0[i], lod0[i + 1]):
single_seq = []
for word_idx in range(win_size):
word_pos = idx + word_idx
dat = input_seq[word_pos] if word_pos < lod0[i+1] \
else pad_value
single_seq.append(dat)
out_seq.append(single_seq)
return out_seq
class TestSequenceEnumerateOp(OpTest):
def setUp(self):
self.op_type = "sequence_enumerate"
self.init_test_case()
self.inputs = {'X': (self.in_seq, self.lod)}
self.attrs = {'win_size': self.win_size, 'pad_value': self.pad_value}
self.outputs = {'Out': (self.out_seq, self.lod)}
def test_check_output(self):
self.check_output()
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int64")
class TestSequenceEnumerateOpLargeWinSize(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 5
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 30
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
class TestSequenceEnumerateOpLargePadValue(TestSequenceEnumerateOp):
def init_test_case(self):
self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
self.lod = [[9, 4, 11, 6]]
self.win_size = 5
self.pad_value = 5
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
if __name__ == "__main__":
unittest.main()
......@@ -1096,7 +1096,8 @@ class DistributeTranspiler(object):
self.table_name]
zero_dim = int(
math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints)))
math.ceil(origin_param_var.shape[0] / float(
len(self.pserver_endpoints))))
table_shape = list(origin_param_var.shape)
table_shape[0] = zero_dim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册