diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 59bb18bfcab30540bd38ca8d1cb300813d30fee8..bf0208834600fef3bcf1b0496da8f5f77aea44c5 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -19,34 +19,34 @@ __all__ = [ def convert_and_compare(x, Type): - """ - Convert x to be the same type as Type and then convert back to - check whether there is a loss of information - :param x: object to be checked - :param Type: target type to check x over - + """ + Convert x to be the same type as Type and then convert back to + check whether there is a loss of information + :param x: object to be checked + :param Type: target type to check x over + """ return type(x)(Type(x)) == x def is_compatible_with(x, Type): - """ - Check if x has a type compatible with Type - :param x: object to be checked - :param Type: target type to check x over - + """ + Check if x has a type compatible with Type + :param x: object to be checked + :param Type: target type to check x over + """ if type(x) == Type: return True try: if float == Type or int == Type: - # avoid those types that can be converted to float/int but not very - # meaningful and could potentially lead to error - # i.e., str and bool typed value should not be used for initializing float/int variable + # avoid those types that can be converted to float/int but not very + # meaningful and could potentially lead to error + # i.e., str and bool typed value should not be used for initializing float/int variable if not isinstance(x, str) and not isinstance(x, bool): return convert_and_compare(x, Type) elif bool == Type: - # should not use string type to initialize bool variable + # should not use string type to initialize bool variable if not isinstance(x, str): return convert_and_compare(x, Type) else: @@ -88,6 +88,10 @@ class ParameterAttribute(object): :type learning_rate: float or None :param momentum: The parameter momentum. None means use global value. :type momentum: float or None + :param gradient_clipping_threshold: gradient clipping threshold. If gradient + value larger than some value, will be + clipped. + :type gradient_clipping_threshold: float :param sparse_update: Enable sparse update for this parameter. It will enable both local and remote sparse update. :type sparse_update: bool @@ -104,6 +108,7 @@ class ParameterAttribute(object): l2_rate=None, learning_rate=None, momentum=None, + gradient_clipping_threshold=None, sparse_update=False): # initialize strategy. if is_static: @@ -152,6 +157,11 @@ class ParameterAttribute(object): self.attr['sparse_update'] = True self.attr['sparse_remote_update'] = True + if gradient_clipping_threshold is not None and \ + is_compatible_with(gradient_clipping_threshold, float): + self.attr['gradient_clipping_threshold'] = \ + gradient_clipping_threshold + def set_default_parameter_name(self, name): """ Set default parameter name. If parameter not set, then will use default