diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py
index d0f9803d2664f697e58ea4ed2087d5c44526e1f9..f46814eb0ae04887ceb4c7c1f674fc360f3644c0 100644
--- a/deepspeech/training/gradclip.py
+++ b/deepspeech/training/gradclip.py
@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
     def __init__(self, clip_norm):
         super().__init__(clip_norm)
 
+    def __repr__(self):
+        return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"
+
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
diff --git a/deepspeech/training/optimizer.py b/deepspeech/training/optimizer.py
index 2e62a7ed71da8b048538d111c9ed1f5144baf131..f7933f8d4f986ea87302330691d6078f2c51c0de 100644
--- a/deepspeech/training/optimizer.py
+++ b/deepspeech/training/optimizer.py
@@ -20,7 +20,7 @@ from paddle.regularizer import L2Decay
 
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.utils.dynamic_import import dynamic_import
-from deepspeech.utils.dynamic_import import filter_valid_args
+from deepspeech.utils.dynamic_import import instance_class
 from deepspeech.utils.log import Log
 
 __all__ = ["OptimizerFactory"]
@@ -80,5 +80,4 @@
 
         args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
 
-        args = filter_valid_args(args)
-        return module_class(**args)
+        return instance_class(module_class, args)
diff --git a/deepspeech/utils/dynamic_import.py b/deepspeech/utils/dynamic_import.py
index 41978bc93f97bbcd34f4719ecefb964900a36c0a..533f15eeefdae8b7d13b8d8742cb5cbddc7696db 100644
--- a/deepspeech/utils/dynamic_import.py
+++ b/deepspeech/utils/dynamic_import.py
@@ -12,15 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+import inspect
 from typing import Any
 from typing import Dict
+from typing import List
 from typing import Text
 
 from deepspeech.utils.log import Log
+from deepspeech.utils.tensor_utils import has_tensor
 
 logger = Log(__name__).getlog()
 
-__all__ = ["dynamic_import", "instance_class", "filter_valid_args"]
+__all__ = ["dynamic_import", "instance_class"]
 
 
 def dynamic_import(import_path, alias=dict()):
@@ -43,14 +46,22 @@ def dynamic_import(import_path, alias=dict()):
     return getattr(m, objname)
 
 
-def filter_valid_args(args: Dict[Text, Any]):
-    # filter out `val` which is None
-    new_args = {key: val for key, val in args.items() if val is not None}
+def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
+    # keep keys in `valid_keys` whose value is not None
+    new_args = {
+        key: val
+        for key, val in args.items() if (key in valid_keys and val is not None)
+    }
     return new_args
 
 
+def filter_out_tensor(args: Dict[Text, Any]):
+    return {key: val for key, val in args.items() if not has_tensor(val)}
+
+
 def instance_class(module_class, args: Dict[Text, Any]):
-    # filter out `val` which is None
-    new_args = filter_valid_args(args)
-    logger.info(f"Instance: {module_class.__name__} {new_args}.")
+    valid_keys = inspect.signature(module_class).parameters.keys()
+    new_args = filter_valid_args(args, valid_keys)
+    logger.info(
+        f"Instance: {module_class.__name__} {filter_out_tensor(new_args)}.")
     return module_class(**new_args)
diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py
index 7679d9e1c573a552eaf0086da12b19dd018c9024..9bff6b0f3f35ccb7a5392481efd2da7fb9e70ea1 100644
--- a/deepspeech/utils/tensor_utils.py
+++ b/deepspeech/utils/tensor_utils.py
@@ -19,11 +19,25 @@ import paddle
 
 from deepspeech.utils.log import Log
 
-__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
+__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
 
 logger = Log(__name__).getlog()
 
 
+def has_tensor(val):
+    if isinstance(val, (list, tuple)):
+        for item in val:
+            if has_tensor(item):
+                return True
+    elif isinstance(val, dict):
+        for v in val.values():
+            if has_tensor(v):
+                return True
+    else:
+        return paddle.is_tensor(val)
+    return False
+
+
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,
                  padding_value: float=0.0) -> paddle.Tensor: