未验证 提交 63f0721b 编写于 作者: S shangliang Xu 提交者: GitHub

modify the deformable_attention_core_func code to fit paddle-trt (#7900)

上级 67df6a1e
......@@ -74,8 +74,8 @@ def deformable_attention_core_func(value, value_spatial_shapes,
"""
Args:
value (Tensor): [bs, value_length, n_head, c]
value_spatial_shapes (Tensor): [n_levels, 2]
value_level_start_index (Tensor): [n_levels]
value_spatial_shapes (Tensor|List): [n_levels, 2]
value_level_start_index (Tensor|List): [n_levels]
sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
......@@ -85,8 +85,8 @@ def deformable_attention_core_func(value, value_spatial_shapes,
bs, _, n_head, c = value.shape
_, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
value_list = value.split(
value_spatial_shapes.prod(1).split(n_levels), axis=1)
split_shape = [h * w for h, w in value_spatial_shapes]
value_list = value.split(split_shape, axis=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (h, w) in enumerate(value_spatial_shapes):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册