未验证 提交 5704a255 编写于 作者: L littletomatodonkey 提交者: GitHub

fix eval (#355)

上级 db035f50
......@@ -55,11 +55,14 @@ def main(args, return_dict={}):
place = 'gpu:{}'.format(ParallelEnv().dev_id) if use_gpu else 'cpu'
place = paddle.set_device(place)
paddle.disable_static(place)
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
config["use_data_parallel"] = use_data_parallel
strategy = paddle.distributed.init_parallel_env()
net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = paddle.DataParallel(net, strategy)
if config["use_data_parallel"]:
strategy = paddle.distributed.init_parallel_env()
net = paddle.DataParallel(net, strategy)
init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)()
net.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册