提交 733b6c21 编写于 作者: W wuzewu

Update trainer

上级 61dcfbe9
......@@ -32,7 +32,7 @@ class Trainer(object):
Args:
model(paddle.nn.Layer) : Model to train or evaluate.
strategy(paddle.optimizer.Optimizer) : Optimizer strategy.
optimizer(paddle.optimizer.Optimizer) : Optimizer for loss.
use_vdl(bool) : Whether to use visualdl to record training data.
checkpoint_dir(str) : Directory where the checkpoint is saved, and the trainer will restore the
state and model parameters from the checkpoint.
......@@ -51,16 +51,19 @@ class Trainer(object):
def __init__(self,
model: paddle.nn.Layer,
strategy: paddle.optimizer.Optimizer,
optimizer: paddle.optimizer.Optimizer,
use_vdl: bool = True,
checkpoint_dir: str = None,
compare_metrics: Callable = None):
self.nranks = paddle.distributed.get_world_size()
self.local_rank = paddle.distributed.get_rank()
self.model = model
self.optimizer = strategy
self.optimizer = optimizer
self.checkpoint_dir = checkpoint_dir if checkpoint_dir else 'ckpt_{}'.format(time.time())
if not isinstance(self.model, paddle.nn.Layer):
raise TypeError('The model {} is not a `paddle.nn.Layer` object.'.format(self.model.__name__))
if self.local_rank == 0 and not os.path.exists(self.checkpoint_dir):
os.makedirs(self.checkpoint_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册