提交 d0ec9c38 编写于 作者: W wjj19950828

fixed for ci

上级 8069fff5
...@@ -21,18 +21,13 @@ from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_d ...@@ -21,18 +21,13 @@ from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_d
@paddle.jit.not_to_static @paddle.jit.not_to_static
def roi_align(input, def roi_align(input,
rois, rois,
output_size, pooled_height,
pooled_width,
spatial_scale=1.0, spatial_scale=1.0,
sampling_ratio=-1, sampling_ratio=-1,
rois_num=None, rois_num=None,
aligned=True, aligned=True,
name=None): name=None):
check_type(output_size, 'output_size', (int, tuple), 'roi_align')
if isinstance(output_size, int):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
if in_dynamic_mode(): if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode." assert rois_num is not None, "rois_num should not be None in dygraph mode."
align_out = _C_ops.roi_align( align_out = _C_ops.roi_align(
...@@ -71,15 +66,15 @@ def roi_align(input, ...@@ -71,15 +66,15 @@ def roi_align(input,
class ROIAlign(object): class ROIAlign(object):
def __init__(self, pooled_height, pooled_width, spatial_scale, def __init__(self, pooled_height, pooled_width, spatial_scale,
sampling_ratio, rois_num): sampling_ratio):
self.roialign_layer_attrs = { self.roialign_layer_attrs = {
"pooled_height": pooled_height, "pooled_height": pooled_height,
"pooled_width": pooled_width, "pooled_width": pooled_width,
"spatial_scale": spatial_scale, "spatial_scale": spatial_scale,
'sampling_ratio': sampling_ratio, 'sampling_ratio': sampling_ratio,
'rois_num': rois_num,
} }
def __call__(self, x0, x1): def __call__(self, x0, x1, x2):
out = roi_align(input=x0, rois=x1, **self.roialign_layer_attrs) out = roi_align(
input=x0, rois=x1, rois_num=x2, **self.roialign_layer_attrs)
return out return out
...@@ -538,12 +538,14 @@ class OpSet9(): ...@@ -538,12 +538,14 @@ class OpSet9():
'pooled_width': pooled_width, 'pooled_width': pooled_width,
'spatial_scale': spatial_scale, 'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio, 'sampling_ratio': sampling_ratio,
'rois_num': val_rois_num,
} }
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'custom_layer:ROIAlign', 'custom_layer:ROIAlign',
inputs={'input': val_x.name, inputs={
'rois': val_rois.name}, 'input': val_x.name,
'rois': val_rois.name,
'rois_num': val_rois_num
},
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册