提交 f49029cc 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #356 from lcy-seso/fix_print_gradient_in_GNR

fix a bug in printing parameter gradients during training for GNR.
#!/usr/bin/env python
#coding=utf-8
import collections
......
#!/usr/bin/env python
#coding=utf-8
import numpy as np
......
#!/usr/bin/env python
#coding=utf-8
__all__ = ["ModelConfig", "TrainerConfig"]
......@@ -36,11 +35,12 @@ class TrainerConfig(object):
epochs = 20
# for debug print, if set to 0, no information will be printed.
# This parameter is for debug printing.
# If it set to 0, no information will be printed.
show_parameter_status_period = 0
checkpoint_period = 100
log_period = 1
log_period = 5
# this is used to resume training, this path can set to previously
# trained model.
# This parameter is used to resume training.
# This path can be set to a previously trained model.
init_model_path = None
# -*- coding: utf-8 -*-
#coding=utf-8
"""
Convert the raw json data into training and validation examples.
"""
......
#!/usr/bin/env python
#coding=utf-8
import os
......
#!/usr/bin/env python
#coding=utf-8
from __future__ import print_function
......@@ -147,22 +146,27 @@ def build_event_handler(config, parameters, trainer):
# End batch and end pass event handler
def event_handler(event):
"""The event handler."""
"""
To print the statistical information of gradients of any learnable
parameter, the event: EndForwardBackward rather than EndIteration
should be handled. For the reason that parameter gradients will be
reset to zeros when EndIteration event happens in GPU training.
"""
if config.show_parameter_status_period and \
isinstance(event, paddle.event.EndForwardBackward):
if not event.batch_id % config.show_parameter_status_period:
show_parameter_status(parameters)
if isinstance(event, paddle.event.EndIteration):
if event.batch_id and \
(not event.batch_id % config.checkpoint_period):
if event.batch_id and not event.batch_id % config.checkpoint_period:
save_path = os.path.join(config.save_dir,
"checkpoint_param.latest.tar.gz")
save_model(save_path, parameters)
if event.batch_id and not event.batch_id % config.log_period:
if not event.batch_id % config.log_period:
logger.info("Pass %d, Batch %d, Cost %f" %
(event.pass_id, event.batch_id, event.cost))
if config.show_parameter_status_period and event.batch_id and \
not (event.batch_id % config.show_parameter_status_period):
show_parameter_status(parameters)
if isinstance(event, paddle.event.EndPass):
save_path = os.path.join(config.save_dir,
"pass_%05d.tar.gz" % event.pass_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册