diff --git a/globally_normalized_reader/basic_modules.py b/globally_normalized_reader/basic_modules.py index 91aefc2f625e12ab8da9232332ad4810abd4e578..bf40e71615e79b0b503cd141fa9989051d5c6c98 100644 --- a/globally_normalized_reader/basic_modules.py +++ b/globally_normalized_reader/basic_modules.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python #coding=utf-8 import collections diff --git a/globally_normalized_reader/beam_decoding.py b/globally_normalized_reader/beam_decoding.py index d072ca17d1a3830c6b56d5fa9c5f2a256952a819..a2ce852046dbe3eb8fffbb5253bf463eb60f279e 100644 --- a/globally_normalized_reader/beam_decoding.py +++ b/globally_normalized_reader/beam_decoding.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python #coding=utf-8 import numpy as np diff --git a/globally_normalized_reader/config.py b/globally_normalized_reader/config.py index 849cc693b646bd131a4f53ca500be48febd1d4ea..2fa48b64d70600a3967895d09b22b8b23ffcd11b 100644 --- a/globally_normalized_reader/config.py +++ b/globally_normalized_reader/config.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python #coding=utf-8 __all__ = ["ModelConfig", "TrainerConfig"] @@ -36,11 +35,12 @@ class TrainerConfig(object): epochs = 20 - # for debug print, if set to 0, no information will be printed. + # This parameter is for debug printing. + # If it is set to 0, no information will be printed. show_parameter_status_period = 0 checkpoint_period = 100 - log_period = 1 + log_period = 5 - # this is used to resume training, this path can set to previously - # trained model. + # This parameter is used to resume training. + # This path can be set to a previously trained model. 
init_model_path = None diff --git a/globally_normalized_reader/featurize.py b/globally_normalized_reader/featurize.py index d0eb9d626b2b41436b319b7b0a2cac284107ed98..d7dd09bc6b0e9188b6589383841cb06eac18e7af 100644 --- a/globally_normalized_reader/featurize.py +++ b/globally_normalized_reader/featurize.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +#coding=utf-8 """ Convert the raw json data into training and validation examples. """ diff --git a/globally_normalized_reader/infer.py b/globally_normalized_reader/infer.py index 351b2659fb974123feed1e16636171399ae89ea7..397cb8ce9497e23919eb3b5752b180f4c67ce394 100644 --- a/globally_normalized_reader/infer.py +++ b/globally_normalized_reader/infer.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python #coding=utf-8 import os diff --git a/globally_normalized_reader/train.py b/globally_normalized_reader/train.py index e377fa1c98b8bfcbd5014d769270146febc96ace..b06b9720adfcbb2a30e32ba6bd61f7a0c04dde72 100644 --- a/globally_normalized_reader/train.py +++ b/globally_normalized_reader/train.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python #coding=utf-8 from __future__ import print_function @@ -147,22 +146,27 @@ def build_event_handler(config, parameters, trainer): # End batch and end pass event handler def event_handler(event): """The event handler.""" + """ + To print the statistical information of gradients of any learnable + parameter, the event: EndForwardBackward rather than EndIteration + should be handled. This is because parameter gradients will be reset to zeros when the EndIteration event happens in GPU training. 
+ """ + if config.show_parameter_status_period and \ + isinstance(event, paddle.event.EndForwardBackward): + if not event.batch_id % config.show_parameter_status_period: + show_parameter_status(parameters) if isinstance(event, paddle.event.EndIteration): - if event.batch_id and \ - (not event.batch_id % config.checkpoint_period): + if event.batch_id and not event.batch_id % config.checkpoint_period: save_path = os.path.join(config.save_dir, "checkpoint_param.latest.tar.gz") save_model(save_path, parameters) - if event.batch_id and not event.batch_id % config.log_period: + if not event.batch_id % config.log_period: logger.info("Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id, event.cost)) - if config.show_parameter_status_period and event.batch_id and \ - not (event.batch_id % config.show_parameter_status_period): - show_parameter_status(parameters) - if isinstance(event, paddle.event.EndPass): save_path = os.path.join(config.save_dir, "pass_%05d.tar.gz" % event.pass_id)