提交 c4da9a7f 编写于 作者: H Hui Zhang

filter key by class signature, no print tensor

上级 3912c255
...@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): ...@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
def __init__(self, clip_norm): def __init__(self, clip_norm):
super().__init__(clip_norm) super().__init__(clip_norm)
def __repr__(self):
return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"
@imperative_base.no_grad @imperative_base.no_grad
def _dygraph_clip(self, params_grads): def _dygraph_clip(self, params_grads):
params_and_grads = [] params_and_grads = []
......
...@@ -20,7 +20,7 @@ from paddle.regularizer import L2Decay ...@@ -20,7 +20,7 @@ from paddle.regularizer import L2Decay
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.dynamic_import import filter_valid_args from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["OptimizerFactory"] __all__ = ["OptimizerFactory"]
...@@ -80,5 +80,4 @@ class OptimizerFactory(): ...@@ -80,5 +80,4 @@ class OptimizerFactory():
args.update({"grad_clip": grad_clip, "weight_decay": weight_decay}) args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
args = filter_valid_args(args) return instance_class(module_class, args)
return module_class(**args)
...@@ -12,15 +12,18 @@ ...@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib import importlib
import inspect
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List
from typing import Text from typing import Text
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import has_tensor
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["dynamic_import", "instance_class", "filter_valid_args"] __all__ = ["dynamic_import", "instance_class"]
def dynamic_import(import_path, alias=dict()): def dynamic_import(import_path, alias=dict()):
...@@ -43,14 +46,22 @@ def dynamic_import(import_path, alias=dict()): ...@@ -43,14 +46,22 @@ def dynamic_import(import_path, alias=dict()):
return getattr(m, objname) return getattr(m, objname)
def filter_valid_args(args: Dict[Text, Any]): def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
# filter out `val` which is None # filter by `valid_keys` and filter `val` is not None
new_args = {key: val for key, val in args.items() if val is not None} new_args = {
key: val
for key, val in args.items() if (key in valid_keys and val is not None)
}
return new_args return new_args
def filter_out_tenosr(args: Dict[Text, Any]):
return {key: val for key, val in args.items() if not has_tensor(val)}
def instance_class(module_class, args: Dict[Text, Any]): def instance_class(module_class, args: Dict[Text, Any]):
# filter out `val` which is None valid_keys = inspect.signature(module_class).parameters.keys()
new_args = filter_valid_args(args) new_args = filter_valid_args(args, valid_keys)
logger.info(f"Instance: {module_class.__name__} {new_args}.") logger.info(
f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.")
return module_class(**new_args) return module_class(**new_args)
...@@ -19,11 +19,25 @@ import paddle ...@@ -19,11 +19,25 @@ import paddle
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"] __all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
def pad_sequence(sequences: List[paddle.Tensor], def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False, batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor: padding_value: float=0.0) -> paddle.Tensor:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册