提交 94ed5969 编写于 作者: K KP

Add cli logger control.

上级 3151637a
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
...@@ -183,10 +182,15 @@ class ASRExecutor(BaseExecutor): ...@@ -183,10 +182,15 @@ class ASRExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -479,7 +483,9 @@ class ASRExecutor(BaseExecutor): ...@@ -479,7 +483,9 @@ class ASRExecutor(BaseExecutor):
decode_method = parser_args.decode_method decode_method = parser_args.decode_method
force_yes = parser_args.yes force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
...@@ -495,7 +501,7 @@ class ASRExecutor(BaseExecutor): ...@@ -495,7 +501,7 @@ class ASRExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List
...@@ -112,10 +111,15 @@ class CLSExecutor(BaseExecutor): ...@@ -112,10 +111,15 @@ class CLSExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -243,7 +247,9 @@ class CLSExecutor(BaseExecutor): ...@@ -243,7 +247,9 @@ class CLSExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
topk = parser_args.topk topk = parser_args.topk
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
...@@ -259,7 +265,7 @@ class CLSExecutor(BaseExecutor): ...@@ -259,7 +265,7 @@ class CLSExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os import os
import sys import sys
from abc import ABC from abc import ABC
...@@ -149,10 +150,16 @@ class BaseExecutor(ABC): ...@@ -149,10 +150,16 @@ class BaseExecutor(ABC):
job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False. job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
""" """
raw_text = self._format_task_results(results) if not self._is_job_input(input_) and len(
print(raw_text, end='') results) == 1: # Only one input sample
raw_text = list(results.values())[0]
else:
raw_text = self._format_task_results(results)
print(raw_text, end='') # Stdout
if self._is_job_input(input_) and job_dump_result: if self._is_job_input(
input_) and job_dump_result: # Dump to *.job.done
try: try:
job_output_file = os.path.abspath(input_) + '.done' job_output_file = os.path.abspath(input_) + '.done'
sys.stdout = open(job_output_file, 'w') sys.stdout = open(job_output_file, 'w')
...@@ -209,3 +216,13 @@ class BaseExecutor(ABC): ...@@ -209,3 +216,13 @@ class BaseExecutor(ABC):
for k, v in results.items(): for k, v in results.items():
ret += f'{k} {v}\n' ret += f'{k} {v}\n'
return ret return ret
def disable_task_loggers(self):
"""
Disable all loggers in current task.
"""
loggers = [
logging.getLogger(name) for name in logging.root.manager.loggerDict
]
for l in loggers:
l.disabled = True
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import subprocess import subprocess
from collections import OrderedDict from collections import OrderedDict
...@@ -110,10 +109,15 @@ class STExecutor(BaseExecutor): ...@@ -110,10 +109,15 @@ class STExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help="Choose device to execute model inference.") help="Choose device to execute model inference.")
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -327,7 +331,9 @@ class STExecutor(BaseExecutor): ...@@ -327,7 +331,9 @@ class STExecutor(BaseExecutor):
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
...@@ -343,7 +349,7 @@ class STExecutor(BaseExecutor): ...@@ -343,7 +349,7 @@ class STExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
import re import re
from collections import OrderedDict from collections import OrderedDict
...@@ -122,10 +121,15 @@ class TextExecutor(BaseExecutor): ...@@ -122,10 +121,15 @@ class TextExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -270,7 +274,9 @@ class TextExecutor(BaseExecutor): ...@@ -270,7 +274,9 @@ class TextExecutor(BaseExecutor):
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab punc_vocab = parser_args.punc_vocab
device = parser_args.device device = parser_args.device
job_dump_result = parser_args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
...@@ -286,7 +292,7 @@ class TextExecutor(BaseExecutor): ...@@ -286,7 +292,7 @@ class TextExecutor(BaseExecutor):
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results, self.process_task_results(parser_args.input, task_results,
job_dump_result) args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import ast
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import Any from typing import Any
...@@ -400,10 +399,15 @@ class TTSExecutor(BaseExecutor): ...@@ -400,10 +399,15 @@ class TTSExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name') '--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument( self.parser.add_argument(
'-d',
'--job_dump_result', '--job_dump_result',
type=ast.literal_eval, action='store_true',
default=False,
help='Save job result into file.') help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
...@@ -693,7 +697,9 @@ class TTSExecutor(BaseExecutor): ...@@ -693,7 +697,9 @@ class TTSExecutor(BaseExecutor):
lang = args.lang lang = args.lang
device = args.device device = args.device
spk_id = args.spk_id spk_id = args.spk_id
job_dump_result = args.job_dump_result
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(args.input) task_source = self.get_task_source(args.input)
task_results = OrderedDict() task_results = OrderedDict()
...@@ -733,7 +739,8 @@ class TTSExecutor(BaseExecutor): ...@@ -733,7 +739,8 @@ class TTSExecutor(BaseExecutor):
has_exceptions = True has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}' task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(args.input, task_results, job_dump_result) self.process_task_results(args.input, task_results,
args.job_dump_result)
if has_exceptions: if has_exceptions:
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册