未验证 提交 3e55f255 编写于 作者: Z zhupengyang 提交者: GitHub

fix delete_repeated_ops_pass, fix multiclass_nms3 (#56434)

上级 8495377a
...@@ -122,6 +122,10 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps( ...@@ -122,6 +122,10 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
Graph* graph) { Graph* graph) {
VLOG(4) << "handle DeleteRepeatedOps"; VLOG(4) << "handle DeleteRepeatedOps";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern); GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
// in_var node may be deleted by the previous detected subgraph
if (graph->Nodes().count(in_var) == 0) {
return;
}
std::vector<std::string> invalid_out_ops{ std::vector<std::string> invalid_out_ops{
"while", "conditional_block", "fetch"}; "while", "conditional_block", "fetch"};
......
...@@ -59,10 +59,18 @@ void MultiClassNMSKernel(const Context& ctx, ...@@ -59,10 +59,18 @@ void MultiClassNMSKernel(const Context& ctx,
rois_num_vec.clear(); rois_num_vec.clear();
if (is_lod) { if (is_lod) {
if (has_rois_num) { if (has_rois_num) {
phi::DenseTensor rois_num_host;
rois_num_host.Resize(rois_num.get_ptr()->dims());
ctx.template HostAlloc<int>(&rois_num_host);
phi::Copy(ctx,
*rois_num.get_ptr(),
rois_num_host.place(),
false,
&rois_num_host);
n = rois_num.get_ptr()->numel(); n = rois_num.get_ptr()->numel();
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
rois_num_vec.push_back(rois_num.get_ptr()->data<int>()[i]); rois_num_vec.push_back(rois_num_host.data<int>()[i]);
boxes_count += rois_num.get_ptr()->data<int>()[i]; boxes_count += rois_num_host.data<int>()[i];
} }
} else { } else {
auto lod = bboxes.lod().back(); auto lod = bboxes.lod().back();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册