提交 d0ec9c38 编写于 作者: W wjj19950828

fixed for ci

上级 8069fff5
......@@ -21,18 +21,13 @@ from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_d
@paddle.jit.not_to_static
def roi_align(input,
rois,
output_size,
pooled_height,
pooled_width,
spatial_scale=1.0,
sampling_ratio=-1,
rois_num=None,
aligned=True,
name=None):
check_type(output_size, 'output_size', (int, tuple), 'roi_align')
if isinstance(output_size, int):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
align_out = _C_ops.roi_align(
......@@ -71,15 +66,15 @@ def roi_align(input,
class ROIAlign(object):
def __init__(self, pooled_height, pooled_width, spatial_scale,
sampling_ratio, rois_num):
sampling_ratio):
self.roialign_layer_attrs = {
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale,
'sampling_ratio': sampling_ratio,
'rois_num': rois_num,
}
def __call__(self, x0, x1):
out = roi_align(input=x0, rois=x1, **self.roialign_layer_attrs)
def __call__(self, x0, x1, x2):
out = roi_align(
input=x0, rois=x1, rois_num=x2, **self.roialign_layer_attrs)
return out
......@@ -538,12 +538,14 @@ class OpSet9():
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio,
'rois_num': val_rois_num,
}
self.paddle_graph.add_layer(
'custom_layer:ROIAlign',
inputs={'input': val_x.name,
'rois': val_rois.name},
inputs={
'input': val_x.name,
'rois': val_rois.name,
'rois_num': val_rois_num
},
outputs=[node.name],
**layer_attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册