未验证 提交 b9ee846e 编写于 作者: Z zyfncg 提交者: GitHub

Add roi_align yaml and unittest (#41402)

* add roi_align yaml

* fix bug
上级 1f829f6e
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import paddle
import unittest import unittest
import numpy as np import numpy as np
import math import math
...@@ -30,6 +31,7 @@ class TestROIAlignOp(OpTest): ...@@ -30,6 +31,7 @@ class TestROIAlignOp(OpTest):
self.inputs = { self.inputs = {
'X': self.x, 'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod), 'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisNum': self.boxes_num
} }
self.attrs = { self.attrs = {
'spatial_scale': self.spatial_scale, 'spatial_scale': self.spatial_scale,
...@@ -170,16 +172,19 @@ class TestROIAlignOp(OpTest): ...@@ -170,16 +172,19 @@ class TestROIAlignOp(OpTest):
rois.append(roi) rois.append(roi)
self.rois_num = len(rois) self.rois_num = len(rois)
self.rois = np.array(rois).astype("float64") self.rois = np.array(rois).astype("float64")
self.boxes_num = np.array(
[bno + 1 for bno in range(self.batch_size)]).astype('int32')
def setUp(self): def setUp(self):
self.op_type = "roi_align" self.op_type = "roi_align"
self.python_api = lambda x, boxes, boxes_num, pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned: paddle.vision.ops.roi_align(x, boxes, boxes_num, (pooled_height, pooled_width), spatial_scale, sampling_ratio, aligned)
self.set_data() self.set_data()
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_eager=True)
class TestROIAlignInLodOp(TestROIAlignOp): class TestROIAlignInLodOp(TestROIAlignOp):
......
...@@ -1344,7 +1344,7 @@ ...@@ -1344,7 +1344,7 @@
param : [n, dtype] param : [n, dtype]
data_type : dtype data_type : dtype
backend : place backend : place
- api : reciprocal - api : reciprocal
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
...@@ -1386,6 +1386,16 @@ ...@@ -1386,6 +1386,16 @@
intermediate : xshape intermediate : xshape
backward: reshape_grad backward: reshape_grad
- api : roi_align
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned)
output : Tensor
infer_meta :
func : RoiAlignInferMeta
kernel :
func : roi_align
optional : boxes_num
backward : roi_align_grad
- api : roll - api : roll
args : (Tensor x, IntArray shifts, int64_t[] axis) args : (Tensor x, IntArray shifts, int64_t[] axis)
output : Tensor(out) output : Tensor(out)
......
...@@ -407,7 +407,7 @@ ...@@ -407,7 +407,7 @@
param : [x] param : [x]
kernel : kernel :
func : expand_as_grad func : expand_as_grad
- backward_api : expm1_grad - backward_api : expm1_grad
forward : expm1 (Tensor x) -> Tensor(out) forward : expm1 (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad) args : (Tensor out, Tensor out_grad)
...@@ -994,6 +994,17 @@ ...@@ -994,6 +994,17 @@
backend: out_grad backend: out_grad
layout: out_grad layout: out_grad
- backward_api : roi_align_grad
forward : roi_align (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned) -> Tensor(out)
args : (Tensor x, Tensor boxes, Tensor boxes_num, Tensor out_grad, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : roi_align_grad
optional : boxes_num
- backward_api : roll_grad - backward_api : roll_grad
forward : roll(Tensor x, IntArray shifts, int64_t[] axis) -> Tensor(out) forward : roll(Tensor x, IntArray shifts, int64_t[] axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray shifts, int64_t[] axis) args : (Tensor x, Tensor out_grad, IntArray shifts, int64_t[] axis)
......
...@@ -176,6 +176,7 @@ def source_include(header_file_path): ...@@ -176,6 +176,7 @@ def source_include(header_file_path):
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
......
...@@ -19,7 +19,7 @@ from ..fluid import core, layers ...@@ -19,7 +19,7 @@ from ..fluid import core, layers
from ..fluid.layers import nn, utils from ..fluid.layers import nn, utils
from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D
from ..fluid.initializer import Normal from ..fluid.initializer import Normal
from ..fluid.framework import _non_static_mode, in_dygraph_mode from ..fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from paddle.common_ops_import import * from paddle.common_ops_import import *
from paddle import _C_ops from paddle import _C_ops
...@@ -1224,7 +1224,12 @@ def roi_align(x, ...@@ -1224,7 +1224,12 @@ def roi_align(x,
output_size = (output_size, output_size) output_size = (output_size, output_size)
pooled_height, pooled_width = output_size pooled_height, pooled_width = output_size
if _non_static_mode(): if in_dygraph_mode():
assert boxes_num is not None, "boxes_num should not be None in dygraph mode."
return _C_ops.final_state_roi_align(x, boxes, boxes_num, pooled_height,
pooled_width, spatial_scale,
sampling_ratio, aligned)
if _in_legacy_dygraph():
assert boxes_num is not None, "boxes_num should not be None in dygraph mode." assert boxes_num is not None, "boxes_num should not be None in dygraph mode."
align_out = _C_ops.roi_align( align_out = _C_ops.roi_align(
x, boxes, boxes_num, "pooled_height", pooled_height, "pooled_width", x, boxes, boxes_num, "pooled_height", pooled_height, "pooled_width",
......
{ {
"phi_apis":["conj", "nll_loss", "flatten", "expand_as", "dropout"], "phi_apis":["conj", "nll_loss", "flatten", "expand_as", "dropout", "roi_align"],
"phi_kernels":["equal_all"] "phi_kernels":["equal_all"]
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册