提交 97f99cd8 编写于 作者: H HydrogenSulfate 提交者: Walter

refactor(retrieval): polish retrieval.py

上级 8542967b
......@@ -13,15 +13,16 @@
# limitations under the License.
from . import logger
from . import metrics
from . import misc
from . import model_zoo
from . import metrics
from .save_load import init_model, save_model
from .config import get_config
from .misc import AverageMeter
from .metrics import multi_hot_encode
from .metrics import hamming_distance
from .dist_utils import all_gather
from .metrics import accuracy_score
from .metrics import precision_recall_fscore
from .metrics import hamming_distance
from .metrics import mean_average_precision
from .metrics import multi_hot_encode
from .metrics import precision_recall_fscore
from .misc import AverageMeter
from .save_load import init_model, save_model
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import paddle
def all_gather(tensor: paddle.Tensor, concat: bool=True,
axis: int=0) -> Union[paddle.Tensor, List[paddle.Tensor]]:
"""Gather tensor from all devices, concatenate them along given axis if specified.
Args:
tensor (paddle.Tensor): Tensor to be gathered from all GPUs.
concat (bool, optional): Whether to concatenate gathered Tensors. Defaults to True.
axis (int, optional): Axis which concatenated along. Defaults to 0.
Returns:
Union[paddle.Tensor, List[paddle.Tensor]]: Gathered Tensors
"""
result = []
paddle.distributed.all_gather(result, tensor)
if concat:
return paddle.concat(result, axis)
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册