From 9eb2d7b3e1c976ad179561ca62be19f41a7584a7 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Thu, 24 Jan 2019 04:28:41 +0000 Subject: [PATCH] refine code, test=develop --- .../operators/detection/multiclass_nms_op.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 43d638228..265bfc6c7 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -171,14 +171,17 @@ void SliceOneClass(const platform::DeviceContext& ctx, const T* items_data = items.data(); const int64_t num_item = items.dims()[0]; const int class_num = items.dims()[1]; - int item_size = 1; if (items.dims().size() == 3) { - item_size = items.dims()[2]; - } - for (int i = 0; i < num_item; ++i) { - std::memcpy(item_data + i * item_size, - items_data + i * class_num * item_size + class_id * item_size, - sizeof(T) * item_size); + int item_size = items.dims()[2]; + for (int i = 0; i < num_item; ++i) { + std::memcpy(item_data + i * item_size, + items_data + i * class_num * item_size + class_id * item_size, + sizeof(T) * item_size); + } + } else { + for (int i = 0; i < num_item; ++i) { + item_data[i] = items_data[i * class_num + class_id]; + } } } -- GitLab