未验证 提交 c3d401b7 编写于 作者: vslyu's avatar vslyu 提交者: GitHub

add multi xpu support for PaddleClas (#678)

上级 a7aa1452
......@@ -119,11 +119,12 @@ def main(args):
init_model(config, train_prog, exe)
if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
optimizer.amp_init(place,
scope=paddle.static.global_scope(),
test_program=valid_prog if config.validate else None)
optimizer.amp_init(
place,
scope=paddle.static.global_scope(),
test_program=valid_prog if config.validate else None)
if not config.get("is_distributed", True) and not use_xpu:
if not config.get("is_distributed", True):
compiled_train_prog = program.compile(
config, train_prog, loss_name=train_fetchs["loss"][0].name)
else:
......@@ -133,10 +134,7 @@ def main(args):
train_dataloader = Reader(config, 'train', places=place)()
if config.validate and paddle.distributed.get_rank() == 0:
valid_dataloader = Reader(config, 'valid', places=place)()
if use_xpu:
compiled_valid_prog = valid_prog
else:
compiled_valid_prog = program.compile(config, valid_prog)
compiled_valid_prog = program.compile(config, valid_prog)
else:
assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
import dali
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册