未验证 提交 82e7a90b 编写于 作者: L littletomatodonkey 提交者: GitHub

fix dist eval (#364)

上级 081ca857
......@@ -59,9 +59,14 @@ def main(args, return_dict={}):
paddle.disable_static(place)
strategy = paddle.distributed.init_parallel_env()
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
config["use_data_parallel"] = use_data_parallel
net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = paddle.DataParallel(net, strategy)
if config["use_data_parallel"]:
strategy = paddle.distributed.init_parallel_env()
net = paddle.DataParallel(net, strategy)
init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)()
net.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册