Commit 65b641bf authored by sweetsky0901

add detection_output op

Parent b41894d1
...@@ -21,42 +21,37 @@ class Detection_output_OpMaker : public framework::OpProtoAndCheckerMaker {
   Detection_output_OpMaker(framework::OpProto* proto,
                            framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput(
-        "Loc",
-        "(Tensor) The input tensor of detection_output operator. "
-        "The format of input tensor is NCHW. Where N is batch size, C is the "
-        "number of channels, H and W is the height and width of feature.");
-    AddInput(
-        "Conf",
-        "(Tensor) The input tensor of detection_output operator. "
-        "The format of input tensor is NCHW. Where N is batch size, C is the "
-        "number of channels, H and W is the height and width of feature.");
-    AddInput(
-        "PriorBox",
-        "(Tensor) The input tensor of detection_output operator. "
-        "The format of input tensor is NCHW. Where N is batch size, C is the "
-        "number of channels, H and W is the height and width of feature.");
-    AddOutput("Out",
-              "(Tensor) The output tensor of detection_output operator."
-              "N * M."
-              "M = C * H * W");
-    AddAttr<int>("background_label_id", "(int), multi level pooling");
-    AddAttr<int>("num_classes", "(int), multi level pooling");
-    AddAttr<float>("nms_threshold", "(int), multi level pooling");
-    AddAttr<float>("confidence_threshold", "(int), multi level pooling");
-    AddAttr<int>("top_k", "(int), multi level pooling");
-    AddAttr<int>("nms_top_k", "(int), multi level pooling");
+    AddInput("Loc",
+             "(Tensor) The input tensor of detection_output operator. "
+             "The format of the input tensor is kNCHW, where K is the number "
+             "of priorbox points, N is the number of boxes on each point, "
+             "C is 4, and H and W are both 1.");
+    AddInput("Conf",
+             "(Tensor) The input tensor of detection_output operator. "
+             "The format of the input tensor is kNCHW, where K is the number "
+             "of priorbox points, N is the number of boxes on each point, "
+             "C is the number of classes, and H and W are both 1.");
+    AddInput("PriorBox",
+             "(Tensor) The input tensor of detection_output operator, "
+             "holding the positions and variances of the prior boxes.");
+    AddOutput("Out",
+              "(Tensor) The output tensor of detection_output operator.");
+    AddAttr<int>("background_label_id",
+                 "(int), the label id of the background class");
+    AddAttr<int>("num_classes",
+                 "(int), the number of prediction classes");
+    AddAttr<float>("nms_threshold",
+                   "(float), the overlap threshold used by NMS");
+    AddAttr<float>("confidence_threshold",
+                   "(float), detections below this score are discarded");
+    AddAttr<int>("top_k",
+                 "(int), the number of detections kept per image");
+    AddAttr<int>("nms_top_k",
+                 "(int), the number of candidates kept per class before NMS");
     AddComment(R"DOC(
-        "Does spatial pyramid pooling on the input image by taking the max,
-        etc. within regions so that the result vector of different sized
-        images are of the same size
-          Input shape: $(N, C_{in}, H_{in}, W_{in})$
-          Output shape: $(H_{out}, W_{out})$
-          Where
-            $$
-              H_{out} = N \\
-              W_{out} = (((4^pyramid_height) - 1) / (4 - 1))$ * C_{in}
-            $$
+Detection output for SSD (single shot multibox detector).
       )DOC");
   }
 };
......
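
To make the Loc, Conf, and PriorBox layouts described by the new doc strings concrete, here is a small NumPy sketch using the same sizes as the Python unit test at the end of this commit (the variable names are illustrative only and not part of the operator's interface):

    import numpy as np

    k, n, num_classes = 1, 4, 2                  # one priorbox point, 4 boxes on it, 2 classes
    loc = np.zeros((k, n, 4, 1, 1))              # kNCHW: C = 4 box offsets, H = W = 1
    conf = np.zeros((k, n, num_classes, 1, 1))   # kNCHW: C = num_classes
    # PriorBox is flat: 8 values per prior box, 4 corner coordinates followed
    # by 4 per-coordinate variances, so num_priors = priorbox.size / 8.
    priorbox = np.zeros(n * 8)
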
...@@ -18,10 +18,34 @@ limitations under the License. */
 #include "paddle/operators/math/detection_util.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/softmax.h"
+#include "paddle/operators/strided_memcpy.h"
 namespace paddle {
 namespace operators {
 template <typename Place, typename T>
+void transpose_fun(const platform::DeviceContext& context,
+                   const framework::Tensor& src, framework::Tensor* dst) {
+  int input_nums = src.dims()[0];
+  int offset = 0;
+  for (int j = 0; j < input_nums; ++j) {
+    framework::Tensor in_p_tensor = src.Slice(j, j + 1);
+    std::vector<int64_t> shape_vec(
+        {in_p_tensor.dims()[0], in_p_tensor.dims()[1], in_p_tensor.dims()[3],
+         in_p_tensor.dims()[4], in_p_tensor.dims()[2]});
+    framework::DDim shape(framework::make_ddim(shape_vec));
+    framework::Tensor in_p_tensor_transpose;
+    in_p_tensor_transpose.mutable_data<T>(shape, context.GetPlace());
+    std::vector<int> shape_axis({0, 1, 3, 4, 2});
+    math::Transpose<Place, T, 5> trans5;
+    trans5(context, in_p_tensor, &in_p_tensor_transpose, shape_axis);
+    auto dst_stride = framework::stride(dst->dims());
+    auto src_stride = framework::stride(in_p_tensor_transpose.dims());
+    StridedMemcpy<T>(context, in_p_tensor_transpose.data<T>(), src_stride,
+                     in_p_tensor_transpose.dims(), dst_stride,
+                     dst->data<T>() + offset);
+    offset += in_p_tensor_transpose.dims()[4] * src_stride[4];
+  }
+}
+template <typename Place, typename T>
 class Detection_output_Kernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
...@@ -37,77 +61,51 @@ class Detection_output_Kernel : public framework::OpKernel<T> {
     float nms_threshold = context.template Attr<float>("nms_threshold");
     float confidence_threshold =
         context.template Attr<float>("confidence_threshold");
-    int input_num = in_loc->dims()[0];
-    int batch_size = in_loc->dims()[1];
-    int channels = in_loc->dims()[2];
-    int height = in_loc->dims()[3];
-    int weight = in_loc->dims()[4];
-    int loc_sum_size = in_loc->numel();
+    int batch_size = in_conf->dims()[1];
     int conf_sum_size = in_conf->numel();
-    std::vector<int64_t> loc_shape_vec({1, loc_sum_size});
-    std::vector<int64_t> conf_shape_vec(
-        {conf_sum_size / num_classes, num_classes});
+    // for softmax
+    std::vector<int64_t> conf_shape_softmax_vec(
+        {conf_sum_size / num_classes, num_classes});
+    framework::DDim conf_shape_softmax(
+        framework::make_ddim(conf_shape_softmax_vec));
+    // for knchw => nhwc
+    std::vector<int64_t> loc_shape_vec({1, in_loc->dims()[1], in_loc->dims()[3],
+                                        in_loc->dims()[4], in_loc->dims()[2]});
+    std::vector<int64_t> conf_shape_vec({1, in_conf->dims()[1],
+                                         in_conf->dims()[3], in_conf->dims()[4],
+                                         in_conf->dims()[2]});
     framework::DDim loc_shape(framework::make_ddim(loc_shape_vec));
     framework::DDim conf_shape(framework::make_ddim(conf_shape_vec));
     framework::Tensor loc_tensor;
     framework::Tensor conf_tensor;
-    loc_tensor.Resize(loc_shape);
-    conf_tensor.Resize(conf_shape);
     loc_tensor.mutable_data<T>(loc_shape, context.GetPlace());
     conf_tensor.mutable_data<T>(conf_shape, context.GetPlace());
+    // for cpu
     framework::Tensor loc_cpu;
     framework::Tensor conf_cpu;
     framework::Tensor priorbox_cpu;
-    const T* in_loc_data = in_loc->data<T>();
-    const T* in_conf_data = in_conf->data<T>();
-    T* loc_data;
-    T* conf_data;
     const T* priorbox_data = in_priorbox->data<T>();
+    transpose_fun<Place, T>(context.device_context(), *in_loc, &loc_tensor);
+    transpose_fun<Place, T>(context.device_context(), *in_conf, &conf_tensor);
+    conf_tensor.Resize(conf_shape_softmax);
+    math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_tensor,
+                                     &conf_tensor);
+    T* loc_data = loc_tensor.data<T>();
+    T* conf_data = conf_tensor.data<T>();
     if (platform::is_gpu_place(context.GetPlace())) {
-      loc_cpu.mutable_data<T>(in_loc->dims(), platform::CPUPlace());
-      framework::CopyFrom(*in_loc, platform::CPUPlace(),
+      loc_cpu.mutable_data<T>(loc_tensor.dims(), platform::CPUPlace());
+      framework::CopyFrom(loc_tensor, platform::CPUPlace(),
                           context.device_context(), &loc_cpu);
-      in_loc_data = loc_cpu.data<T>();
-      conf_cpu.mutable_data<T>(in_conf->dims(), platform::CPUPlace());
-      framework::CopyFrom(*in_conf, platform::CPUPlace(),
+      loc_data = loc_cpu.data<T>();
+      conf_cpu.mutable_data<T>(conf_tensor.dims(), platform::CPUPlace());
+      framework::CopyFrom(conf_tensor, platform::CPUPlace(),
                           context.device_context(), &conf_cpu);
-      in_conf_data = conf_cpu.data<T>();
+      conf_data = conf_cpu.data<T>();
       priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace());
       framework::CopyFrom(*in_priorbox, platform::CPUPlace(),
                           context.device_context(), &priorbox_cpu);
       priorbox_data = priorbox_cpu.data<T>();
-      loc_tensor.mutable_data<T>(loc_shape, platform::CPUPlace());
-      conf_tensor.mutable_data<T>(conf_shape, platform::CPUPlace());
-    }
-    T* loc_tensor_data = loc_tensor.data<T>();
-    T* conf_tensor_data = conf_tensor.data<T>();
-    for (int i = 0; i < input_num; ++i) {
-      math::appendWithPermute<T>(in_loc_data, input_num, batch_size, channels,
-                                 height, weight, loc_tensor_data);
-      math::appendWithPermute<T>(in_conf_data, input_num, batch_size, channels,
-                                 height, weight, conf_tensor_data);
-    }
-    loc_data = loc_tensor.data<T>();
-    if (platform::is_gpu_place(context.GetPlace())) {
-      framework::Tensor conf_gpu;
-      conf_gpu.Resize(conf_shape);
-      conf_gpu.mutable_data<T>(conf_shape, context.GetPlace());
-      framework::CopyFrom(conf_tensor, platform::GPUPlace(),
-                          context.device_context(), &conf_gpu);
-      // softmax
-      math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_gpu,
-                                       &conf_gpu);
-      conf_tensor.mutable_data<T>(conf_gpu.dims(), platform::CPUPlace());
-      framework::CopyFrom(conf_gpu, platform::CPUPlace(),
-                          context.device_context(), &conf_tensor);
-    } else {
-      // softmax
-      math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_tensor,
-                                       &conf_tensor);
     }
-    conf_data = conf_tensor.data<T>();
     // get decode bboxes
     size_t num_priors = in_priorbox->numel() / 8;
     std::vector<std::vector<operators::math::BBox<T>>> all_decoded_bboxes;
......
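
For intuition, the layout work done by transpose_fun together with the reshape-and-softmax step above can be sketched in NumPy for the single-priorbox-point case (K = 1) exercised by the unit test. This is an illustrative sketch under that assumption, not the kernel itself, and the helper names are made up for the example:

    import numpy as np

    def knchw_to_nhwc(x):
        # Permute each K-slice from (1, N, C, H, W) to (1, N, H, W, C),
        # mirroring the (0, 1, 3, 4, 2) axis order used by transpose_fun,
        # and lay the slices out one after another.
        slices = [np.transpose(x[k:k + 1], (0, 1, 3, 4, 2)) for k in range(x.shape[0])]
        return np.concatenate(slices, axis=-1)

    def softmax_rows(x, num_classes):
        # View the transposed confidences as one row of raw scores per box,
        # matching the (numel / num_classes, num_classes) reshape before softmax.
        rows = x.reshape(-1, num_classes)
        e = np.exp(rows - rows.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)

    conf = np.random.rand(1, 4, 2, 1, 1)             # (K, N, C, H, W) with C = num_classes
    conf_prob = softmax_rows(knchw_to_nhwc(conf), num_classes=2)
    print(conf_prob.shape)                           # (4, 2): one probability row per box
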
...@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <map>
 #include "paddle/framework/selected_rows.h"
 #include "paddle/platform/device_context.h"
 namespace paddle {
 namespace operators {
 namespace math {
 template <typename T>
 struct BBox {
   BBox(T x_min, T y_min, T x_max, T y_max)
...@@ -49,31 +49,47 @@ struct BBox {
   bool is_difficult;
 };
 // KNCHW ==> NHWC
-// template <typename T>
 template <typename T>
-int appendWithPermute(const T* input_data, int input_nums, int batch_size,
-                      int channels, int height, int weight, T* output_data) {
-  int image_size = height * weight;
-  int numel = input_nums * batch_size * channels * height * weight;
-  int offset = 0;
-  for (int p = 0; p < input_nums; ++p) {
-    int in_p_offset = p * batch_size * channels * image_size;
-    for (int n = 0; n < batch_size; ++n) {
-      int in_n_offset = n * channels * image_size;
-      int out_n_offset = n * numel / batch_size + offset;
-      int in_stride = image_size;
-      int out_stride = channels;
-      const T* in_data = input_data + in_p_offset + in_n_offset;
-      T* out_data = output_data + out_n_offset;
-      for (int c = 0; c < channels; ++c) {
-        for (int i = 0; i < image_size; ++i) {
-          out_data[out_stride * i + c] = in_data[c * in_stride + i];
-        }
-      }
-    }
-    offset += image_size * channels;
-  }
-  return 0;
-}
+void getBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
+                          std::vector<BBox<T>>& bbox_vec);
+template <typename T>
+void getBBoxVarFromPriorData(const T* prior_data, const size_t num,
+                             std::vector<std::vector<T>>& var_vec);
+template <typename T>
+BBox<T> decodeBBoxWithVar(BBox<T>& prior_bbox,
+                          const std::vector<T>& prior_bbox_var,
+                          const std::vector<T>& loc_pred_data);
+template <typename T1, typename T2>
+bool sortScorePairDescend(const std::pair<T1, T2>& pair1,
+                          const std::pair<T1, T2>& pair2);
+template <typename T>
+bool sortScorePairDescend(const std::pair<T, BBox<T>>& pair1,
+                          const std::pair<T, BBox<T>>& pair2);
+template <typename T>
+T jaccardOverlap(const BBox<T>& bbox1, const BBox<T>& bbox2);
+
+template <typename T>
+void applyNMSFast(const std::vector<BBox<T>>& bboxes, const T* conf_score_data,
+                  size_t class_idx, size_t top_k, T conf_threshold,
+                  T nms_threshold, size_t num_priors, size_t num_classes,
+                  std::vector<size_t>* indices);
+template <typename T>
+int getDetectionIndices(
+    const T* conf_data, const size_t num_priors, const size_t num_classes,
+    const size_t background_label_id, const size_t batch_size,
+    const T conf_threshold, const size_t nms_top_k, const T nms_threshold,
+    const size_t top_k,
+    const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes,
+    std::vector<std::map<size_t, std::vector<size_t>>>* all_detection_indices);
+template <typename T>
+BBox<T> clipBBox(const BBox<T>& bbox);
+template <typename T>
+void getDetectionOutput(
+    const T* conf_data, const size_t num_kept, const size_t num_priors,
+    const size_t num_classes, const size_t batch_size,
+    const std::vector<std::map<size_t, std::vector<size_t>>>& all_indices,
+    const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes, T* out_data);
 template <typename T>
 void getBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
                           std::vector<BBox<T>>& bbox_vec) {
...@@ -136,9 +152,6 @@ bool sortScorePairDescend(const std::pair<T1, T2>& pair1,
   return pair1.first > pair2.first;
 }
 template <typename T>
-bool sortScorePairDescend(const std::pair<T, BBox<T>>& pair1,
-                          const std::pair<T, BBox<T>>& pair2);
-template <typename T>
 T jaccardOverlap(const BBox<T>& bbox1, const BBox<T>& bbox2) {
   if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min ||
       bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) {
...@@ -281,7 +294,6 @@ void getDetectionOutput(
       }
     }
   }
-  // out.copyFrom(out_data, num_kept * 7);
 }
 }  // namespace math
 }  // namespace operators
......
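
The geometric helpers declared in this header follow the usual SSD conventions. Below is a minimal NumPy sketch of what decodeBBoxWithVar and jaccardOverlap compute, assuming variance-encoded center/size deltas and intersection-over-union (the Python names are illustrative only):

    import numpy as np

    def decode_bbox_with_var(prior, var, loc):
        # prior and the returned box are [x_min, y_min, x_max, y_max];
        # var holds the 4 per-coordinate variances, loc the predicted offsets.
        pw, ph = prior[2] - prior[0], prior[3] - prior[1]
        pcx, pcy = prior[0] + 0.5 * pw, prior[1] + 0.5 * ph
        cx = var[0] * loc[0] * pw + pcx
        cy = var[1] * loc[1] * ph + pcy
        w = pw * np.exp(var[2] * loc[2])
        h = ph * np.exp(var[3] * loc[3])
        return [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]

    def jaccard_overlap(a, b):
        # Intersection-over-union of two [x_min, y_min, x_max, y_max] boxes.
        iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
        ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
        inter = iw * ih
        union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
        return inter / union if inter > 0 else 0.0
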
...@@ -8,22 +8,24 @@ class TestUnpoolOp(OpTest):
         self.op_type = "detection_output"
         self.init_test_case()
-        #loc = np.zeros((1, 4, 4, 1, 1))
-        #conf = np.zero((1, 4, 2, 1, 1))
+        #loc.shape ((1, 4, 4, 1, 1))
+        #conf.shape ((1, 4, 2, 1, 1))
         loc = np.array([[[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
                          [[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
                          [[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
                          [[[0.1]], [[0.1]], [[0.1]], [[0.1]]]]])
-        conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]]],
-                         [[[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]])
-        priorbox = np.array([0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2,\
-                             0.2, 0.2, 0.6, 0.6, 0.1, 0.1, 0.2, 0.2,\
-                             0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2,\
-                             0.4, 0.4, 0.8, 0.8, 0.1, 0.1, 0.2, 0.2])
-        output = np.array([0, 1, 0.68997443, 0.099959746, 0.099959746,\
-                           0.50804031, 0.50804031])
+        conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]],
+                          [[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]])
+        priorbox = np.array([
+            0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.1,
+            0.1, 0.2, 0.2, 0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2, 0.4, 0.4,
+            0.8, 0.8, 0.1, 0.1, 0.2, 0.2
+        ])
+        output = np.array([
+            0, 1, 0.68997443, 0.099959746, 0.099959746, 0.50804031, 0.50804031
+        ])
         self.inputs = {
             'Loc': loc.astype('float32'),
             'Conf': conf.astype('float32'),
......
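
The expected output row in this test can be reproduced by hand from the first prior box [0.1, 0.1, 0.5, 0.5] with variances [0.1, 0.1, 0.2, 0.2], its location prediction of 0.1 for all four offsets, and raw class scores [0.1, 0.9], assuming the decode and softmax conventions sketched above:

    import numpy as np

    score = np.exp(0.9) / (np.exp(0.1) + np.exp(0.9))   # softmax over [0.1, 0.9] -> ~0.68997443
    pw = ph = 0.5 - 0.1                                  # prior box width and height
    cx = cy = 0.1 + 0.5 * pw + 0.1 * 0.1 * pw            # decoded center: ~0.304
    w = h = pw * np.exp(0.2 * 0.1)                       # decoded size: ~0.4080805
    x_min = y_min = cx - 0.5 * w                         # ~0.099959746
    x_max = y_max = cx + 0.5 * w                         # ~0.50804031
    # Matches the expected output row: [0, 1, 0.68997443, 0.099959746, ...]
    print([0, 1, float(score), x_min, y_min, x_max, y_max])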