未验证 提交 1e417485 编写于 作者: W wangguanzhong 提交者: GitHub

update for paddle 2.0rc (#1583)

上级 6ac9743c
......@@ -52,7 +52,7 @@ TrainReader:
drop_last: true
worker_num: 4
bufsize: 4
use_process: true
use_process: false #true
EvalReader:
......
......@@ -92,7 +92,7 @@ class YOLOFeat(nn.Layer):
if i < self.num_levels - 1:
route = self.route_blocks[i](route)
route = F.resize_nearest(route, scale=2.)
route = F.interpolate(route, scale_factor=2.)
return yolo_feats
......
......@@ -23,7 +23,7 @@ from ppdet.utils.stats import TrainingStats
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
from paddle.distributed import ParallelEnv
import paddle.distributed as dist
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
......@@ -87,8 +87,16 @@ def parse_args():
return args
def run(FLAGS, cfg):
def run():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
if FLAGS.dist:
trainer_id = int(env['PADDLE_TRAINER_ID'])
......@@ -117,7 +125,7 @@ def run(FLAGS, cfg):
load_static_weights=cfg.get('load_static_weights', False))
# Parallel Model
if ParallelEnv().nranks > 1:
if dist.ParallelEnv().nranks > 1:
strategy = paddle.distributed.init_parallel_env()
model = paddle.DataParallel(model, strategy)
......@@ -151,7 +159,7 @@ def run(FLAGS, cfg):
# Model Backward
loss = outputs['loss']
if ParallelEnv().nranks > 1:
if dist.ParallelEnv().nranks > 1:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
......@@ -163,7 +171,7 @@ def run(FLAGS, cfg):
lr.step()
optimizer.clear_grad()
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if dist.ParallelEnv().nranks < 2 or dist.ParallelEnv().local_rank == 0:
# Log state
if iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys())
......@@ -185,19 +193,7 @@ def run(FLAGS, cfg):
def main():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
place = paddle.CUDAPlace(ParallelEnv().dev_id) \
if cfg.use_gpu else paddle.CPUPlace()
paddle.disable_static(place)
run(FLAGS, cfg)
dist.spawn(run)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册