diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index a40950d9ff5327b2856eff601425483e600ece59..b19233fdeecfbb9084a65c66795eabd81440d928 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -74,8 +74,8 @@ def deformable_attention_core_func(value, value_spatial_shapes, """ Args: value (Tensor): [bs, value_length, n_head, c] - value_spatial_shapes (Tensor): [n_levels, 2] - value_level_start_index (Tensor): [n_levels] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] @@ -85,8 +85,8 @@ def deformable_attention_core_func(value, value_spatial_shapes, bs, _, n_head, c = value.shape _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape - value_list = value.split( - value_spatial_shapes.prod(1).split(n_levels), axis=1) + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.split(split_shape, axis=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level, (h, w) in enumerate(value_spatial_shapes):