提交 8eb98cb3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2988 Add attr “roi_end_mode" in ROIAlign for both GE and VM backends.

Merge pull request !2988 from liuxiao93/ROIAlign
......@@ -610,7 +610,8 @@ OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(y)}};
ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int>())},
{"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int>())},
{"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())},
{"sample_num", ATTR_DESC(sample_num, AnyTraits<int>())}};
{"sample_num", ATTR_DESC(sample_num, AnyTraits<int>())},
{"roi_end_mode", ATTR_DESC(roi_end_mode, AnyTraits<int>())}};
// ROIAlignGrad
INPUT_MAP(ROIAlignGrad) = {{1, INPUT_DESC(ydiff)}, {2, INPUT_DESC(rois)}};
......
......@@ -27,7 +27,7 @@ roi_align_op_info = TBERegOp("ROIAlign") \
.attr("pooled_height", "required", "int", "all") \
.attr("pooled_width", "required", "int", "all") \
.attr("sample_num", "optional", "int", "all", "2") \
.attr("roi_end_mode", "optional", "0,1", "1") \
.attr("roi_end_mode", "optional", "int", "0,1", "1") \
.input(0, "features", False, "required", "all") \
.input(1, "rois", False, "required", "all") \
.input(2, "rois_n", False, "optional", "all") \
......
......@@ -2695,6 +2695,7 @@ class ROIAlign(PrimitiveWithInfer):
feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the
input feature map, the `spatial_scale` should be `fea_h / ori_h`.
sample_num (int): Number of sampling points. Default: 2.
roi_end_mode (int): Number must be 0 or 1. Default: 1.
Inputs:
- **features** (Tensor) - The input features, whose shape should be `(N, C, H, W)`.
......@@ -2717,16 +2718,19 @@ class ROIAlign(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2):
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2, roi_end_mode=1):
"""init ROIAlign"""
validator.check_value_type("pooled_height", pooled_height, [int], self.name)
validator.check_value_type("pooled_width", pooled_width, [int], self.name)
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
validator.check_value_type("sample_num", sample_num, [int], self.name)
validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name)
self.pooled_height = pooled_height
self.pooled_width = pooled_width
self.spatial_scale = spatial_scale
self.sample_num = sample_num
self.roi_end_mode = roi_end_mode
def infer_shape(self, inputs_shape, rois_shape):
return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册