diff --git a/ppdet/data/sampler.py b/ppdet/data/sampler.py index 1ad79bde1ce6249fbd03d0f57a363c54207c576a..c07b5fe4beff5b2f5b783639cfe977b5de8411cc 100644 --- a/ppdet/data/sampler.py +++ b/ppdet/data/sampler.py @@ -7,11 +7,8 @@ import socket import contextlib import numpy as np -from paddle import fluid from paddle.io import BatchSampler -from paddle.fluid.layers import collective from paddle.distributed import ParallelEnv -from paddle.fluid.dygraph.parallel import ParallelStrategy _parallel_context_initialized = False @@ -85,95 +82,3 @@ class DistributedBatchSampler(BatchSampler): def set_epoch(self, epoch): self.epoch = epoch - - -def wait_server_ready(endpoints): - assert not isinstance(endpoints, six.string_types) - while True: - all_ok = True - not_ready_endpoints = [] - for ep in endpoints: - ip_port = ep.split(":") - with contextlib.closing( - socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - sock.settimeout(2) - result = sock.connect_ex((ip_port[0], int(ip_port[1]))) - if result != 0: - all_ok = False - not_ready_endpoints.append(ep) - if not all_ok: - time.sleep(3) - else: - break - - -def init_communicator(program, rank, nranks, wait_port, current_endpoint, - endpoints): - if nranks < 2: - return - other_endpoints = endpoints[:] - other_endpoints.remove(current_endpoint) - if rank == 0 and wait_port: - wait_server_ready(other_endpoints) - block = program.global_block() - nccl_id_var = block.create_var( - name=fluid.unique_name.generate('nccl_id'), - persistable=True, - type=fluid.core.VarDesc.VarType.RAW) - - block.append_op( - type='c_gen_nccl_id', - inputs={}, - outputs={'Out': nccl_id_var}, - attrs={ - 'rank': rank, - 'endpoint': current_endpoint, - 'other_endpoints': other_endpoints - }) - - block.append_op( - type='c_comm_init', - inputs={'X': nccl_id_var}, - outputs={}, - attrs={ - 'nranks': nranks, - 'rank': rank, - 'ring_id': 0, - }) - - -def prepare_distributed_context(place=None): - if place is None: - place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \ - else fluid.CUDAPlace(0) - - strategy = ParallelStrategy() - strategy.nranks = ParallelEnv().nranks - strategy.local_rank = ParallelEnv().local_rank - strategy.trainer_endpoints = ParallelEnv().trainer_endpoints - strategy.current_endpoint = ParallelEnv().current_endpoint - - if strategy.nranks < 2: - return - - global _parallel_context_initialized - - if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace): - - def _init_context(): - communicator_prog = fluid.Program() - init_communicator(communicator_prog, strategy.local_rank, - strategy.nranks, True, strategy.current_endpoint, - strategy.trainer_endpoints) - exe = fluid.Executor(place) - exe.run(communicator_prog) - - fluid.disable_dygraph() - _init_context() - fluid.enable_dygraph(place) - - else: - assert ("Only support CUDAPlace for now.") - - _parallel_context_initialized = True - return strategy diff --git a/ppdet/modeling/architecture/meta_arch.py b/ppdet/modeling/architecture/meta_arch.py index 06e27bca93c278f8b18b6c60cdfde4a077bf3025..2731b3d456ff0fc08fb71bb57538f3715b9ca26f 100644 --- a/ppdet/modeling/architecture/meta_arch.py +++ b/ppdet/modeling/architecture/meta_arch.py @@ -6,7 +6,6 @@ import numpy as np import paddle import paddle.nn as nn from ppdet.core.workspace import register -from ppdet.utils.data_structure import BufferDict __all__ = ['BaseArch'] diff --git a/ppdet/modeling/bbox.py b/ppdet/modeling/bbox.py index 771f59c7d583b330f05bd58a3711c218e5fabb33..9c494d45741a4562d5dda4359c96787c346240bd 100644 --- a/ppdet/modeling/bbox.py +++ b/ppdet/modeling/bbox.py @@ -1,5 +1,4 @@ import numpy as np -import paddle.fluid as fluid import paddle import paddle.nn as nn import paddle.nn.functional as F @@ -29,20 +28,20 @@ class Anchor(object): rpn_delta_list = [] anchor_list = [] for (rpn_score, rpn_delta), (anchor, var) in zip(rpn_feats, anchors): - rpn_score = fluid.layers.transpose(rpn_score, perm=[0, 2, 3, 1]) - rpn_delta = fluid.layers.transpose(rpn_delta, perm=[0, 2, 3, 1]) - rpn_score = fluid.layers.reshape(x=rpn_score, shape=(0, -1, 1)) - rpn_delta = fluid.layers.reshape(x=rpn_delta, shape=(0, -1, 4)) + rpn_score = paddle.transpose(rpn_score, perm=[0, 2, 3, 1]) + rpn_delta = paddle.transpose(rpn_delta, perm=[0, 2, 3, 1]) + rpn_score = paddle.reshape(x=rpn_score, shape=(0, -1, 1)) + rpn_delta = paddle.reshape(x=rpn_delta, shape=(0, -1, 4)) - anchor = fluid.layers.reshape(anchor, shape=(-1, 4)) - var = fluid.layers.reshape(var, shape=(-1, 4)) + anchor = paddle.reshape(anchor, shape=(-1, 4)) + var = paddle.reshape(var, shape=(-1, 4)) rpn_score_list.append(rpn_score) rpn_delta_list.append(rpn_delta) anchor_list.append(anchor) - rpn_scores = fluid.layers.concat(rpn_score_list, axis=1) - rpn_deltas = fluid.layers.concat(rpn_delta_list, axis=1) - anchors = fluid.layers.concat(anchor_list) + rpn_scores = paddle.concat(rpn_score_list, axis=1) + rpn_deltas = paddle.concat(rpn_delta_list, axis=1) + anchors = paddle.concat(anchor_list) return rpn_scores, rpn_deltas, anchors def generate_loss_inputs(self, inputs, rpn_head_out, anchors): @@ -102,7 +101,7 @@ class Proposal(object): rpn_rois_num_list = [] for (rpn_score, rpn_delta), (anchor, var) in zip(rpn_head_out, anchor_out): - rpn_prob = fluid.layers.sigmoid(rpn_score) + rpn_prob = F.sigmoid(rpn_score) rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = self.proposal_generator( scores=rpn_prob, bbox_deltas=rpn_delta, diff --git a/ppdet/modeling/mask.py b/ppdet/modeling/mask.py index 8911e34be9877ad2226d89adcdb74f03a35a3efe..c4d1a0b6222d2c5e0d4d38501ceb18d586cb89a4 100644 --- a/ppdet/modeling/mask.py +++ b/ppdet/modeling/mask.py @@ -1,5 +1,4 @@ import numpy as np -import paddle.fluid as fluid from ppdet.core.workspace import register diff --git a/ppdet/modeling/neck/fpn.py b/ppdet/modeling/neck/fpn.py index 318c83ca09d1077dccd002bb8ef8317dc3232e42..5565bfdddf664a4e399b60d6043d3899bb9e4f19 100644 --- a/ppdet/modeling/neck/fpn.py +++ b/ppdet/modeling/neck/fpn.py @@ -14,13 +14,12 @@ import numpy as np import paddle -import paddle.fluid as fluid import paddle.nn.functional as F from paddle import ParamAttr from paddle.nn import Layer from paddle.nn import Conv2D from paddle.nn.initializer import XavierUniform -from paddle.fluid.regularizer import L2Decay +from paddle.regularizer import L2Decay from ppdet.core.workspace import register, serializable diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index bf321f0535a343be6f4d64b5b847f13986de490a..21e9e4ce723a657b3981c758a4ce82da07b7c00b 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -23,7 +23,7 @@ import paddle import paddle.nn as nn import paddle.optimizer as optimizer -import paddle.fluid.regularizer as regularizer +import paddle.regularizer as regularizer from paddle import cos from ppdet.core.workspace import register, serializable diff --git a/ppdet/utils/bbox_utils.py b/ppdet/utils/bbox_utils.py index ff16e8b9d101a49ac250e7834f2fd6b3d2c22015..c6204d73bc141cd9208127f164cfac3f0f934d9f 100644 --- a/ppdet/utils/bbox_utils.py +++ b/ppdet/utils/bbox_utils.py @@ -19,8 +19,6 @@ from __future__ import print_function import logging import numpy as np -import paddle.fluid as fluid - __all__ = ["bbox_overlaps", "box_to_delta"] logger = logging.getLogger(__name__) diff --git a/ppdet/utils/check.py b/ppdet/utils/check.py index 382f99ed354c6358dec5a6098f1fe6730a11d527..32f889c03851430610bd35fbf3ff0afbc0a368fb 100644 --- a/ppdet/utils/check.py +++ b/ppdet/utils/check.py @@ -19,7 +19,6 @@ from __future__ import print_function import sys import paddle -from paddle import fluid import logging import six import paddle.version as fluid_version @@ -40,7 +39,7 @@ def check_gpu(use_gpu): "model on CPU" try: - if use_gpu and not fluid.is_compiled_with_cuda(): + if use_gpu and not paddle.is_compiled_with_cuda(): logger.error(err) sys.exit(1) except Exception as e: diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index a280cbeed679458c4357be439994a4ba0e548540..2452aec566274c377f196672e8d6576b3c9259c4 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -23,7 +23,6 @@ import time import re import numpy as np import paddle -import paddle.fluid as fluid from .download import get_weights_path import logging logger = logging.getLogger(__name__) diff --git a/ppdet/utils/dist_utils.py b/ppdet/utils/dist_utils.py deleted file mode 100644 index 32eead4a797ba70cb6980e0368ff9873102680c2..0000000000000000000000000000000000000000 --- a/ppdet/utils/dist_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import - -import os - -import paddle.fluid as fluid - - -def nccl2_prepare(trainer_id, startup_prog, main_prog): - config = fluid.DistributeTranspilerConfig() - config.mode = "nccl2" - t = fluid.DistributeTranspiler(config=config) - t.transpile( - trainer_id, - trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'), - current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'), - startup_program=startup_prog, - program=main_prog) - - -def prepare_for_multi_process(exe, build_strategy, startup_prog, main_prog): - trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0)) - num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - if num_trainers < 2: - return - build_strategy.num_trainers = num_trainers - build_strategy.trainer_id = trainer_id - nccl2_prepare(trainer_id, startup_prog, main_prog) diff --git a/ppdet/utils/post_process.py b/ppdet/utils/post_process.py index cf251998348504332675799bd3d3ccfb342d515c..ad8f2f23365c0d09941642a5bc558ff2360a42f0 100644 --- a/ppdet/utils/post_process.py +++ b/ppdet/utils/post_process.py @@ -19,7 +19,6 @@ from __future__ import print_function import logging import numpy as np import cv2 -import paddle.fluid as fluid __all__ = ['nms'] diff --git a/tools/export_utils.py b/tools/export_utils.py index 48068b5e60519d7954c850a6d58bba77386914a0..61a0c0dcce807f8ec9a6030b202fced2c201960a 100644 --- a/tools/export_utils.py +++ b/tools/export_utils.py @@ -24,8 +24,6 @@ from collections import OrderedDict import logging logger = logging.getLogger(__name__) -import paddle.fluid as fluid - __all__ = ['dump_infer_config', 'save_infer_model'] # Global dictionary