From 63f0721b2d34034a1be6a0104896e1a75c7bad94 Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Thu, 9 Mar 2023 20:57:30 +0800 Subject: [PATCH] modify the deformable_attention_core_func code to fit paddle-trt (#7900) --- ppdet/modeling/transformers/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ppdet/modeling/transformers/utils.py b/ppdet/modeling/transformers/utils.py index a40950d9f..b19233fde 100644 --- a/ppdet/modeling/transformers/utils.py +++ b/ppdet/modeling/transformers/utils.py @@ -74,8 +74,8 @@ def deformable_attention_core_func(value, value_spatial_shapes, """ Args: value (Tensor): [bs, value_length, n_head, c] - value_spatial_shapes (Tensor): [n_levels, 2] - value_level_start_index (Tensor): [n_levels] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] @@ -85,8 +85,8 @@ def deformable_attention_core_func(value, value_spatial_shapes, bs, _, n_head, c = value.shape _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape - value_list = value.split( - value_spatial_shapes.prod(1).split(n_levels), axis=1) + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.split(split_shape, axis=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level, (h, w) in enumerate(value_spatial_shapes): -- GitLab