未验证 提交 b9ee846e 编写于 作者: Z zyfncg 提交者: GitHub

Add roi_align yaml and unittest (#41402)

* add roi_align yaml

* fix bug
上级 1f829f6e
......@@ -14,6 +14,7 @@
from __future__ import print_function
import paddle
import unittest
import numpy as np
import math
......@@ -30,6 +31,7 @@ class TestROIAlignOp(OpTest):
self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisNum': self.boxes_num
}
self.attrs = {
'spatial_scale': self.spatial_scale,
......@@ -170,16 +172,19 @@ class TestROIAlignOp(OpTest):
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("float64")
self.boxes_num = np.array(
[bno + 1 for bno in range(self.batch_size)]).astype('int32')
def setUp(self):
self.op_type = "roi_align"
self.python_api = lambda x, boxes, boxes_num, pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned: paddle.vision.ops.roi_align(x, boxes, boxes_num, (pooled_height, pooled_width), spatial_scale, sampling_ratio, aligned)
self.set_data()
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestROIAlignInLodOp(TestROIAlignOp):
......
......@@ -1344,7 +1344,7 @@
param : [n, dtype]
data_type : dtype
backend : place
- api : reciprocal
args : (Tensor x)
output : Tensor
......@@ -1386,6 +1386,16 @@
intermediate : xshape
backward: reshape_grad
- api : roi_align
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned)
output : Tensor
infer_meta :
func : RoiAlignInferMeta
kernel :
func : roi_align
optional : boxes_num
backward : roi_align_grad
- api : roll
args : (Tensor x, IntArray shifts, int64_t[] axis)
output : Tensor(out)
......
......@@ -407,7 +407,7 @@
param : [x]
kernel :
func : expand_as_grad
- backward_api : expm1_grad
forward : expm1 (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
......@@ -994,6 +994,17 @@
backend: out_grad
layout: out_grad
- backward_api : roi_align_grad
forward : roi_align (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned) -> Tensor(out)
args : (Tensor x, Tensor boxes, Tensor boxes_num, Tensor out_grad, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : roi_align_grad
optional : boxes_num
- backward_api : roll_grad
forward : roll(Tensor x, IntArray shifts, int64_t[] axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray shifts, int64_t[] axis)
......
......@@ -176,6 +176,7 @@ def source_include(header_file_path):
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
......
......@@ -19,7 +19,7 @@ from ..fluid import core, layers
from ..fluid.layers import nn, utils
from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D
from ..fluid.initializer import Normal
from ..fluid.framework import _non_static_mode, in_dygraph_mode
from ..fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from paddle.common_ops_import import *
from paddle import _C_ops
......@@ -1224,7 +1224,12 @@ def roi_align(x,
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
if _non_static_mode():
if in_dygraph_mode():
assert boxes_num is not None, "boxes_num should not be None in dygraph mode."
return _C_ops.final_state_roi_align(x, boxes, boxes_num, pooled_height,
pooled_width, spatial_scale,
sampling_ratio, aligned)
if _in_legacy_dygraph():
assert boxes_num is not None, "boxes_num should not be None in dygraph mode."
align_out = _C_ops.roi_align(
x, boxes, boxes_num, "pooled_height", pooled_height, "pooled_width",
......
{
"phi_apis":["conj", "nll_loss", "flatten", "expand_as", "dropout"],
"phi_apis":["conj", "nll_loss", "flatten", "expand_as", "dropout", "roi_align"],
"phi_kernels":["equal_all"]
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册