Commit b2df6de8 authored by jerrywgz

fix potential hang in generate proposals, test=develop

Parent 765c70a1
@@ -286,7 +286,8 @@ static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
   }
   int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
   memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
-               sizeof(int) * num_to_keep, 0);
+               sizeof(int) * num_to_keep, ctx.stream());
+  ctx.Wait();
 }
 
 template <typename T>
@@ -329,7 +330,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
   int keep_num;
   const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
   memory::Copy(platform::CPUPlace(), &keep_num, gpu_place,
-               keep_num_t.data<int>(), sizeof(int), 0);
+               keep_num_t.data<int>(), sizeof(int), ctx.stream());
+  ctx.Wait();
   keep_index.Resize({keep_num});
   Tensor scores_filter, proposals_filter;
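In the hunk above, the device-to-host read of keep_num is now enqueued on ctx.stream() and followed by an explicit ctx.Wait(), instead of being issued with a stream argument of 0. The sketch below is plain CUDA rather than Paddle's memory::Copy wrapper, and the names (d_count, h_count, stream) are hypothetical; it only illustrates the same pattern of enqueuing the copy on the producing stream and blocking the host before the value is used.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int *d_count = nullptr;  // hypothetical device-side counter (stand-in for keep_num_t)
  cudaMalloc(&d_count, sizeof(int));
  cudaMemsetAsync(d_count, 0, sizeof(int), stream);  // stand-in for the kernels that write it

  int h_count = 0;
  // Enqueue the copy on the producing stream instead of the default stream,
  // so it is ordered after the work that computed the value.
  cudaMemcpyAsync(&h_count, d_count, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  // Block the host until the copy has finished; only then is h_count valid,
  // analogous to ctx.Wait() after memory::Copy in the hunk above.
  cudaStreamSynchronize(stream);

  printf("count (stand-in for keep_num) = %d\n", h_count);

  cudaFree(d_count);
  cudaStreamDestroy(stream);
  return 0;
}
```

Note that with a pageable stack variable as the destination the copy is staged by the runtime rather than fully asynchronous; a truly asynchronous device-to-host copy would need pinned memory (cudaMallocHost). Either way the explicit synchronization before using the value is what matters here.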
@@ -438,9 +440,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
       Tensor &scores = box_score_pair.second;
       memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
-                   proposals.data<T>(), sizeof(T) * proposals.numel(), 0);
+                   proposals.data<T>(), sizeof(T) * proposals.numel(),
+                   dev_ctx.stream());
       memory::Copy(place, rpn_roi_probs_data + num_proposals, place,
-                   scores.data<T>(), sizeof(T) * scores.numel(), 0);
+                   scores.data<T>(), sizeof(T) * scores.numel(),
+                   dev_ctx.stream());
+      dev_ctx.Wait();
       num_proposals += proposals.dims()[0];
       offset.emplace_back(num_proposals);
     }
...
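The last hunk moves both memory::Copy calls onto dev_ctx.stream() and adds a single dev_ctx.Wait() after the pair. A rough analogue in plain CUDA is sketched below; the buffer names and sizes are made up for illustration, the point being that several asynchronous copies can be enqueued on one stream and covered by one synchronization.

```cpp
#include <cuda_runtime.h>

int main() {
  const size_t n = 1024;  // hypothetical number of proposals
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float *src_boxes, *src_scores, *dst_boxes, *dst_scores;
  cudaMalloc(&src_boxes, 4 * n * sizeof(float));
  cudaMalloc(&src_scores, n * sizeof(float));
  cudaMalloc(&dst_boxes, 4 * n * sizeof(float));
  cudaMalloc(&dst_scores, n * sizeof(float));

  // Both copies are enqueued on the same stream, so they execute in order
  // without blocking the host in between.
  cudaMemcpyAsync(dst_boxes, src_boxes, 4 * n * sizeof(float),
                  cudaMemcpyDeviceToDevice, stream);
  cudaMemcpyAsync(dst_scores, src_scores, n * sizeof(float),
                  cudaMemcpyDeviceToDevice, stream);
  // A single synchronization covers both copies, analogous to the one
  // dev_ctx.Wait() after the two memory::Copy calls in the hunk above.
  cudaStreamSynchronize(stream);

  cudaFree(src_boxes); cudaFree(src_scores);
  cudaFree(dst_boxes); cudaFree(dst_scores);
  cudaStreamDestroy(stream);
  return 0;
}
```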