Unverified commit 1e417485, authored by wangguanzhong, committed by GitHub

update for paddle 2.0rc (#1583)

Parent 6ac9743c
@@ -52,7 +52,7 @@ TrainReader:
   drop_last: true
   worker_num: 4
   bufsize: 4
-  use_process: true
+  use_process: false #true
 EvalReader:
...
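Reviewer note, not part of the diff: `use_process` selects between multi-process and multi-thread prefetching in this legacy reader, and turning it off sidesteps multiprocessing issues under the 2.0rc dygraph trainer. As a rough analogy only (illustrative code, not from this repo), the native `paddle.io.DataLoader` exposes the same trade-off through `num_workers`:

import paddle
from paddle.io import DataLoader, Dataset

# Illustrative sketch: num_workers=0 loads batches in the main process
# (roughly "use_process: false"), num_workers>0 starts worker subprocesses.
class RandomDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return paddle.rand([3, 32, 32]), paddle.randint(0, 10, [1])

loader = DataLoader(RandomDataset(), batch_size=4, shuffle=True,
                    drop_last=True, num_workers=0)
for image, label in loader:
    pass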
@@ -92,7 +92,7 @@ class YOLOFeat(nn.Layer):
             if i < self.num_levels - 1:
                 route = self.route_blocks[i](route)
-                route = F.resize_nearest(route, scale=2.)
+                route = F.interpolate(route, scale_factor=2.)
         return yolo_feats
...
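Reviewer note, not part of the diff: the old `resize_nearest` helper is replaced here by `paddle.nn.functional.interpolate`, whose default mode is nearest-neighbor, so `scale_factor=2.` reproduces the previous `scale=2.` behavior. A minimal standalone sketch (tensor shapes are made up for illustration):

import paddle
import paddle.nn.functional as F

# Nearest-neighbor upsampling by 2x on an NCHW feature map; 'nearest'
# is the default mode of F.interpolate in Paddle 2.0.
x = paddle.rand([1, 128, 20, 20])
up = F.interpolate(x, scale_factor=2.)
print(up.shape)  # [1, 128, 40, 40]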
@@ -23,7 +23,7 @@ from ppdet.utils.stats import TrainingStats
 from ppdet.utils.check import check_gpu, check_version, check_config
 from ppdet.utils.cli import ArgsParser
 from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
-from paddle.distributed import ParallelEnv
+import paddle.distributed as dist
 import logging
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -87,8 +87,16 @@ def parse_args():
     return args

-def run(FLAGS, cfg):
+def run():
+    FLAGS = parse_args()
+    cfg = load_config(FLAGS.config)
+    merge_config(FLAGS.opt)
+    check_config(cfg)
+    check_gpu(cfg.use_gpu)
+    check_version()
     env = os.environ
     FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
     if FLAGS.dist:
         trainer_id = int(env['PADDLE_TRAINER_ID'])
@@ -117,7 +125,7 @@ def run(FLAGS, cfg):
         load_static_weights=cfg.get('load_static_weights', False))

     # Parallel Model
-    if ParallelEnv().nranks > 1:
+    if dist.ParallelEnv().nranks > 1:
         strategy = paddle.distributed.init_parallel_env()
         model = paddle.DataParallel(model, strategy)
@@ -151,7 +159,7 @@ def run(FLAGS, cfg):
         # Model Backward
         loss = outputs['loss']
-        if ParallelEnv().nranks > 1:
+        if dist.ParallelEnv().nranks > 1:
             loss = model.scale_loss(loss)
             loss.backward()
             model.apply_collective_grads()
@@ -163,7 +171,7 @@ def run(FLAGS, cfg):
         lr.step()
         optimizer.clear_grad()
-        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
+        if dist.ParallelEnv().nranks < 2 or dist.ParallelEnv().local_rank == 0:
             # Log state
             if iter_id == 0:
                 train_stats = TrainingStats(cfg.log_iter, outputs.keys())
@@ -185,19 +193,7 @@ def run(FLAGS, cfg):

 def main():
-    FLAGS = parse_args()
-    cfg = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
-    check_config(cfg)
-    check_gpu(cfg.use_gpu)
-    check_version()
-    place = paddle.CUDAPlace(ParallelEnv().dev_id) \
-        if cfg.use_gpu else paddle.CPUPlace()
-    paddle.disable_static(place)
-    run(FLAGS, cfg)
+    dist.spawn(run)

 if __name__ == "__main__":
...
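Reviewer note, not part of the diff: with this change main() no longer picks a CUDAPlace and calls run(FLAGS, cfg) directly; paddle.distributed.spawn launches one worker process per device, and each worker parses its own FLAGS and config inside run(). A minimal standalone sketch of that launch pattern (the function name and nprocs value are illustrative, and it assumes a multi-GPU machine):

import paddle.distributed as dist

def worker():
    # Each spawned process sees its own rank via ParallelEnv.
    env = dist.ParallelEnv()
    print("rank {} of {}".format(env.local_rank, env.nranks))

if __name__ == "__main__":
    # nprocs=2 is illustrative; by default spawn uses all visible GPUs.
    dist.spawn(worker, nprocs=2)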