Commit c4da9a7f authored by Hui Zhang

filter keys by class signature, do not print tensors

Parent 3912c255
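In short, the change makes instance construction filter keyword arguments against the target class's constructor signature and keep tensor values out of the log message. A minimal standalone sketch of the idea (the `DummyOptimizer` class and values are made up, not project code):

```python
import inspect


class DummyOptimizer:
    def __init__(self, learning_rate, weight_decay=None):
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay


# Candidate kwargs: some are None, some are not accepted by the constructor.
args = {"learning_rate": 1e-3, "weight_decay": None, "momentum": 0.9}

# Keep only keys the constructor accepts and whose values are not None.
valid_keys = inspect.signature(DummyOptimizer).parameters.keys()
kwargs = {k: v for k, v in args.items() if k in valid_keys and v is not None}

opt = DummyOptimizer(**kwargs)  # constructed with learning_rate=0.001 only
```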
@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
    def __init__(self, clip_norm):
        super().__init__(clip_norm)

+    def __repr__(self):
+        return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        params_and_grads = []
......
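For illustration, the new `__repr__` surfaces the clip threshold; assuming a clip norm of 5.0, something like:

```python
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog

clip = ClipGradByGlobalNormWithLog(clip_norm=5.0)
print(repr(clip))  # ClipGradByGlobalNormWithLog(global_clip_norm=5.0)
```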
@@ -20,7 +20,7 @@ from paddle.regularizer import L2Decay
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils.dynamic_import import dynamic_import
-from deepspeech.utils.dynamic_import import filter_valid_args
+from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log

__all__ = ["OptimizerFactory"]
@@ -80,5 +80,4 @@ class OptimizerFactory():
        args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
-        args = filter_valid_args(args)
-        return module_class(**args)
+        return instance_class(module_class, args)
@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
+import inspect
from typing import Any
from typing import Dict
+from typing import List
from typing import Text

from deepspeech.utils.log import Log
+from deepspeech.utils.tensor_utils import has_tensor

logger = Log(__name__).getlog()

-__all__ = ["dynamic_import", "instance_class", "filter_valid_args"]
+__all__ = ["dynamic_import", "instance_class"]


def dynamic_import(import_path, alias=dict()):
@@ -43,14 +46,22 @@ def dynamic_import(import_path, alias=dict()):
    return getattr(m, objname)


-def filter_valid_args(args: Dict[Text, Any]):
-    # filter out `val` which is None
-    new_args = {key: val for key, val in args.items() if val is not None}
+def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
+    # keep only the keys listed in `valid_keys` whose values are not None
+    new_args = {
+        key: val
+        for key, val in args.items() if (key in valid_keys and val is not None)
+    }
    return new_args

+def filter_out_tenosr(args: Dict[Text, Any]):
+    return {key: val for key, val in args.items() if not has_tensor(val)}

def instance_class(module_class, args: Dict[Text, Any]):
-    # filter out `val` which is None
-    new_args = filter_valid_args(args)
-    logger.info(f"Instance: {module_class.__name__} {new_args}.")
+    # keep only kwargs that `module_class.__init__` actually accepts
+    valid_keys = inspect.signature(module_class).parameters.keys()
+    new_args = filter_valid_args(args, valid_keys)
+    # report the kwargs, but leave tensor values out of the log message
+    logger.info(
+        f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.")
    return module_class(**new_args)
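A hedged usage sketch of the updated `instance_class`: the `FakeOptimizer` class and the argument values below are made up, but they show which kwargs survive filtering and which are hidden from the log message.

```python
import paddle

from deepspeech.utils.dynamic_import import instance_class


class FakeOptimizer:
    def __init__(self, learning_rate, parameters=None, grad_clip=None):
        self.learning_rate = learning_rate
        self.parameters = parameters
        self.grad_clip = grad_clip


args = {
    "learning_rate": 1e-3,
    "parameters": paddle.to_tensor([1.0, 2.0]),  # passed to the constructor, hidden from the log
    "grad_clip": None,  # dropped: value is None
    "momentum": 0.9,  # dropped: not in FakeOptimizer's signature
}
opt = instance_class(FakeOptimizer, args)
# logs roughly: Instance: FakeOptimizer {'learning_rate': 0.001}.
```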
@@ -19,11 +19,25 @@ import paddle
from deepspeech.utils.log import Log
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Log(__name__).getlog()

+def has_tensor(val):
+    """Return True if `val` is a paddle Tensor or a container holding one."""
+    if isinstance(val, (list, tuple)):
+        for item in val:
+            if has_tensor(item):
+                return True
+    elif isinstance(val, dict):
+        for v in val.values():
+            if has_tensor(v):
+                return True
+    else:
+        return paddle.is_tensor(val)
+    return False


def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
......
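And an illustrative check of `has_tensor` on nested containers (values made up):

```python
import paddle

from deepspeech.utils.tensor_utils import has_tensor

# A tensor nested anywhere inside lists, tuples, or dicts is detected.
assert has_tensor({"params": [paddle.to_tensor([1.0, 2.0])], "lr": 1e-3})
# Plain Python values are not reported as tensors.
assert not has_tensor({"lr": 1e-3, "name": "adam"})
```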