提交 b5a4bb1b 编写于 作者: W wuzewu

Fix encoding bug

上级 d34e9473
...@@ -22,6 +22,7 @@ import sys ...@@ -22,6 +22,7 @@ import sys
import requests import requests
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common.utils import sys_stdin_encoding
from paddlehub.common import stats from paddlehub.common import stats
from paddlehub.commands.base_command import BaseCommand from paddlehub.commands.base_command import BaseCommand
from paddlehub.commands import show from paddlehub.commands import show
...@@ -63,7 +64,7 @@ def main(): ...@@ -63,7 +64,7 @@ def main():
argv = [] argv = []
for item in sys.argv: for item in sys.argv:
if six.PY2: if six.PY2:
argv.append(item.decode(sys.stdin.encoding).decode("utf8")) argv.append(item.decode(sys_stdin_encoding()).decode("utf8"))
else: else:
argv.append(item) argv.append(item)
command.execute(argv[1:]) command.execute(argv[1:])
...@@ -73,7 +74,7 @@ if __name__ == "__main__": ...@@ -73,7 +74,7 @@ if __name__ == "__main__":
argv = [] argv = []
for item in sys.argv: for item in sys.argv:
if six.PY2: if six.PY2:
argv.append(item.decode(sys.stdin.encoding).decode("utf8")) argv.append(item.decode(sys_stdin_encoding()).decode("utf8"))
else: else:
argv.append(item) argv.append(item)
command.execute(argv[1:]) command.execute(argv[1:])
...@@ -17,6 +17,7 @@ from __future__ import absolute_import ...@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys
import os import os
import time import time
import multiprocessing import multiprocessing
...@@ -231,3 +232,29 @@ def get_running_device_info(config): ...@@ -231,3 +232,29 @@ def get_running_device_info(config):
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
return place, dev_count return place, dev_count
def get_platform_default_encoding():
    """Return a last-resort text encoding name chosen by operating system.

    Returns:
        str: ``"gbk"`` when ``platform.platform()`` starts with "windows"
        (compared case-insensitively), otherwise ``"utf8"``.
    """
    # Imported locally so this helper works even when the module has no
    # top-level ``import platform``.
    import platform

    if platform.platform().lower().startswith("windows"):
        return "gbk"
    return "utf8"
def sys_stdin_encoding():
    """Return an encoding name to use for text arriving on standard input.

    Resolution order: ``sys.stdin.encoding``, then
    ``sys.getdefaultencoding()``, then ``get_platform_default_encoding()``.

    Returns:
        str: The first non-empty encoding name found in that order.
    """
    # sys.stdin may be None or may have been replaced by an object that has
    # no ``encoding`` attribute; getattr keeps that from raising
    # AttributeError and lets the fallbacks below take over.
    encoding = getattr(sys.stdin, "encoding", None)
    if not encoding:
        encoding = sys.getdefaultencoding()
    if not encoding:
        encoding = get_platform_default_encoding()
    return encoding
def sys_stdout_encoding():
    """Return an encoding name to use for text written to standard output.

    Resolution order: ``sys.stdout.encoding``, then
    ``sys.getdefaultencoding()``, then ``get_platform_default_encoding()``.

    Returns:
        str: The first non-empty encoding name found in that order.
    """
    # sys.stdout may be None or may have been replaced by an object that has
    # no ``encoding`` attribute; getattr keeps that from raising
    # AttributeError and lets the fallbacks below take over.
    encoding = getattr(sys.stdout, "encoding", None)
    if not encoding:
        encoding = sys.getdefaultencoding()
    if not encoding:
        encoding = get_platform_default_encoding()
    return encoding
...@@ -21,6 +21,8 @@ import codecs ...@@ -21,6 +21,8 @@ import codecs
import sys import sys
import yaml import yaml
from paddlehub.common.utils import sys_stdin_encoding
class CSVFileParser(object): class CSVFileParser(object):
def __init__(self): def __init__(self):
...@@ -30,7 +32,7 @@ class CSVFileParser(object): ...@@ -30,7 +32,7 @@ class CSVFileParser(object):
pass pass
def parse(self, csv_file): def parse(self, csv_file):
with codecs.open(csv_file, "r", sys.stdin.encoding) as file: with codecs.open(csv_file, "r", sys_stdin_encoding()) as file:
content = file.read() content = file.read()
content = content.split('\n') content = content.split('\n')
self.title = content[0].split(',') self.title = content[0].split(',')
...@@ -57,7 +59,7 @@ class YAMLFileParser(object): ...@@ -57,7 +59,7 @@ class YAMLFileParser(object):
pass pass
def parse(self, yaml_file): def parse(self, yaml_file):
with codecs.open(yaml_file, "r", sys.stdin.encoding) as file: with codecs.open(yaml_file, "r", sys_stdin_encoding()) as file:
content = file.read() content = file.read()
return yaml.load(content, Loader=yaml.BaseLoader) return yaml.load(content, Loader=yaml.BaseLoader)
...@@ -70,7 +72,7 @@ class TextFileParser(object): ...@@ -70,7 +72,7 @@ class TextFileParser(object):
pass pass
def parse(self, txt_file): def parse(self, txt_file):
with codecs.open(txt_file, "r", sys.stdin.encoding) as file: with codecs.open(txt_file, "r", sys_stdin_encoding()) as file:
contents = [] contents = []
for line in file: for line in file:
line = line.strip() line = line.strip()
......
...@@ -29,6 +29,7 @@ import paddle ...@@ -29,6 +29,7 @@ import paddle
from paddlehub.reader import tokenization from paddlehub.reader import tokenization
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common.utils import sys_stdout_encoding
from paddlehub.dataset.dataset import InputExample from paddlehub.dataset.dataset import InputExample
from .batching import pad_batch_data from .batching import pad_batch_data
import paddlehub as hub import paddlehub as hub
...@@ -527,7 +528,7 @@ class LACClassifyReader(object): ...@@ -527,7 +528,7 @@ class LACClassifyReader(object):
] ]
if len(processed) == 0: if len(processed) == 0:
if six.PY2: if six.PY2:
text = text.encode(sys.stdout.encoding) text = text.encode(sys_stdout_encoding())
logger.warning( logger.warning(
"The words in text %s can't be found in the vocabulary." % "The words in text %s can't be found in the vocabulary." %
(text)) (text))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册