Unverified · Commit 3592ba8c, authored by Nyakku Shigure, committed by GitHub

[CodeStyle][py2] remove `six` package (part2) (#47334)

* [CodeStyle][py2] remove `six` package (part2)

* six.ensure_str

* remove unused `import six`

* remove six from BUILTIN_LIKELY_MODULES

* remove six in example code

* remove some decode

* try to fix example code

* fix MockEtcdClient get/get_prefix returns data type

* fix MockEtcdClient get_prefix returns data

* fix MockEtcdClient get returns data

* remove `six` in pypi and conda requirements

* fix MockEtcdClient add_watch_callback/add_watch_prefix_callback returns data type

* refine MockEtcdClient
Parent: 3097a66d
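
Note on the `six.ensure_str` replacements below: under Python 3 the etcd client hands back `bytes`, and `bytes.decode()` (UTF-8 by default) produces the same `str` that `six.ensure_str` did, so the calls are replaced with plain `.decode()`. A minimal, standalone sketch of the replacement (not Paddle code):

    # Sketch of the six.ensure_str -> bytes.decode() replacement.
    raw = b"10.10.10.1:6001"      # the kind of value etcd returns

    # before: host = six.ensure_str(raw)
    host = raw.decode()           # Python 3 only; decodes as UTF-8 by default

    assert host == "10.10.10.1:6001"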
@@ -55,7 +55,6 @@ requirements:
     - protobuf>=3.1.0
     - gast==0.3.3
     - Pillow
-    - six
     - decorator
     - astor
 """
@@ -67,7 +66,6 @@ requirements:
     - protobuf>=3.1.0
     - gast==0.3.3
     - Pillow
-    - six
     - decorator
     - astor
 """
......
@@ -15,7 +15,6 @@
 import time
 import socket
 import os
-import six
 import copy
 import signal
 import random
@@ -244,8 +243,7 @@ class ElasticManager(object):
         # register callback
         def host_call_back(event):
             self.hosts = [
-                six.ensure_str(i[0])
-                for i in self.etcd.get_prefix(self.node_prefix)
+                i[0].decode() for i in self.etcd.get_prefix(self.node_prefix)
             ]
             self.hosts = list(set(self.hosts)) if self.hosts else self.hosts
             logger.info(
@@ -266,7 +264,7 @@ class ElasticManager(object):
                 host_lease.refresh()
                 hosts = [
-                    six.ensure_str(i[0])
+                    i[0].decode()
                     for i in self.etcd.get_prefix(self.node_prefix)
                 ]
                 hosts = list(set(hosts)) if hosts else hosts
@@ -311,7 +309,8 @@ class ElasticManager(object):
         def endpoints_call_back(event):
             if not self.dist_endpoints:
                 return
-            edps = six.ensure_str(self.etcd.get(self.endpoints_path)[0] or '')
+            value = self.etcd.get(self.endpoints_path)[0]
+            edps = value.decode() if value is not None else ''
             self.dist_endpoints, self.trainers = edps.split('|')
             logger.info(
                 "set DISTRIBUTED_TRAINER_ENDPOINTS {} ".format(
@@ -426,8 +425,7 @@ class ElasticManager(object):
             self.hosts = host_list
         else:
             self.hosts = [
-                six.ensure_str(i[0])
-                for i in self.etcd.get_prefix(self.node_prefix)
+                i[0].decode() for i in self.etcd.get_prefix(self.node_prefix)
             ]
         self.hosts = list(set(self.hosts)) if self.hosts else self.hosts
......
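Background on the `get_prefix` comprehensions above: `python-etcd3`'s `get_prefix()` yields `(value, metadata)` pairs in which `value` is `bytes`, so the code takes `i[0]` and decodes it. A hedged sketch of that access pattern; the endpoint and keys are purely illustrative and assume a locally running etcd with the `etcd3` package installed:

    import etcd3

    client = etcd3.client(host="127.0.0.1", port=2379)   # illustrative endpoint
    client.put("/nodes/host1", "10.10.10.1:6001")
    client.put("/nodes/host2", "10.10.10.2:6001")

    # get_prefix() yields (value, metadata) pairs; value is bytes
    hosts = [value.decode() for value, _meta in client.get_prefix("/nodes/")]
    print(hosts)   # e.g. ['10.10.10.1:6001', '10.10.10.2:6001']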
@@ -17,7 +17,6 @@ from paddle.distributed.launch.utils.kv_server import KVServer
 import time
 import sys
-import six
 import threading
 import copy
 import random
@@ -214,22 +213,22 @@ class ETCDMaster(Master):
             if len(result) == size:
                 if rank < 0:
-                    keys = [six.ensure_str(i[1].key) for i in result]
-                    sorted_keys = [six.ensure_str(i[1].key) for i in result]
+                    keys = [i[1].key.decode() for i in result]
+                    sorted_keys = [i[1].key.decode() for i in result]
                     sorted_keys.sort()
-                    values = [six.ensure_str(i[0]) for i in result]
+                    values = [i[0].decode() for i in result]
                     ret = [values[keys.index(k)] for k in sorted_keys]
                     idx = ret.index(value)
                     return ret, idx
                 else:
                     ret = [None] * size
                     for v, k in result:
-                        ii = int(six.ensure_str(k.key).split('/')[-1])
+                        ii = int(k.key.decode().split('/')[-1])
                         if ii < 0:
                             self.ctx.logger.error(
                                 "rank {} error in sync".format(ii)
                             )
-                        ret[ii] = six.ensure_str(v)
+                        ret[ii] = v.decode()
                     return ret, rank
             else:
                 time.sleep(0.5)
@@ -278,8 +277,7 @@ class ETCDMaster(Master):
     def fetch_peer_alive(self):
         peer_alive = [
-            six.ensure_str(i[0])
-            for i in self.client.get_prefix(self.heartbeat_prefix)
+            i[0].decode() for i in self.client.get_prefix(self.heartbeat_prefix)
         ]
         self.ctx.logger.debug("peer alive {}".format(peer_alive))
         return peer_alive
@@ -319,7 +317,8 @@ class ETCDMaster(Master):
         ), "set status failed {}".format(status)

     def get_status(self):
-        return six.ensure_str(self.client.get(self.job_prefix)[0] or '')
+        value = self.client.get(self.job_prefix)[0]
+        return value.decode() if value is not None else ''

     def stop(self):
         if hasattr(self, 'beat_thread'):
......
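The `get_status` change follows the same pattern as `endpoints_call_back` in the elastic manager: `etcd3`'s `get()` returns a `(value, metadata)` tuple whose value is `None` when the key is absent, so the old `six.ensure_str(... or '')` becomes an explicit `None` check before decoding. A minimal sketch (the key name is a hypothetical example):

    # None-safe decode that replaces six.ensure_str(value or '')
    def read_status(client, key="/paddle/job/status"):
        value, _metadata = client.get(key)     # (None, None) when the key is missing
        return value.decode() if value is not None else ''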
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import six
 import abc
 import copy
 import math
......
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
-import six
 import sys
 import time
 import signal
@@ -284,9 +283,9 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
                     except:
                         self._exit_thread_expectedly()
-            except:
+            except Exception as e:
                 self._exit_thread_unexpectedly()
-                six.reraise(*sys.exc_info())
+                raise e

             self._exit_thread_expectedly()
@@ -334,7 +333,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
         except StopIteration:
             self._reader.shutdown()
             self._try_shutdown_all()
-            six.reraise(*sys.exc_info())
+            raise
         finally:
             if in_profiler_mode():
                 trace_event.end()
@@ -629,7 +628,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                     self._blocking_queue.close()
             except Exception as e:
                 self._exit_thread_unexpectedly()
-                six.reraise(*sys.exc_info())
+                raise e
             finally:
                 self._rcvd_idx += 1
@@ -715,7 +714,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                     "DataLoader reader thread failed({}) to read data from "
                     "workers' result queue.".format(e)
                 )
-                six.reraise(*sys.exc_info())
+                raise e
             else:
                 if self._dataset_kind == _DatasetKind.ITER and isinstance(
                     data, _IterableDatasetStopIteration
@@ -850,7 +849,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             if not self._persistent_workers:
                 self._reader.shutdown()
                 self._try_shutdown_all()
-            six.reraise(*sys.exc_info())
+            raise
         finally:
             if in_profiler_mode():
                 trace_event.end()
......
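A note on the `six.reraise(*sys.exc_info())` replacements in the DataLoader code above: inside an `except` block, Python 3 keeps the active exception and its traceback, so a bare `raise` is the direct equivalent, while `raise e` re-raises the bound exception object, whose `__traceback__` still points at the original failure. A small standalone sketch of both forms:

    def cleanup():
        pass   # stand-in for e.g. self._exit_thread_unexpectedly()

    def bare_reraise():
        try:
            raise ValueError("boom")
        except Exception:
            cleanup()
            raise        # re-raises the active exception, traceback intact

    def reraise_bound():
        try:
            raise ValueError("boom")
        except Exception as e:
            cleanup()
            raise e      # same effect here; the caught exception carries __traceback__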
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
-import six
 import sys
 import paddle
 import numpy as np
@@ -395,7 +394,7 @@ def _worker_loop(
         # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
         pass
     except:
-        six.reraise(*sys.exc_info())
+        raise
     finally:
         if use_shared_memory:
             _cleanup_mmap()
@@ -22,7 +22,6 @@ import re
 import types
 import numpy
-import six
 import builtins
 from paddle.fluid.dygraph.container import Sequential
@@ -58,7 +57,6 @@ BUILTIN_LIKELY_MODULES = [
     copy,
     inspect,
     re,
-    six,
     numpy,
     logging,
 ]
......
@@ -19,7 +19,6 @@ import sys
 import warnings
 import numpy as np
 from .wrapped_decorator import signature_safe_contextmanager
-import six
 from .data_feeder import convert_dtype
 from .framework import Program, default_main_program, Variable, Operator
 from .framework import convert_np_dtype_to_dtype_, _apply_pass
@@ -1574,23 +1573,20 @@ class Executor(object):
         ]
         self._log_force_set_program_cache(use_program_cache)
-        try:
-            res = self._run_impl(
-                program=program,
-                feed=feed,
-                fetch_list=fetch_list,
-                feed_var_name=feed_var_name,
-                fetch_var_name=fetch_var_name,
-                scope=scope,
-                return_numpy=return_numpy,
-                use_program_cache=use_program_cache,
-                use_prune=use_prune,
-                return_merged=return_merged,
-            )
-            core.update_autotune_status()
-            return res
-        except Exception as e:
-            six.reraise(*sys.exc_info())
+        res = self._run_impl(
+            program=program,
+            feed=feed,
+            fetch_list=fetch_list,
+            feed_var_name=feed_var_name,
+            fetch_var_name=fetch_var_name,
+            scope=scope,
+            return_numpy=return_numpy,
+            use_program_cache=use_program_cache,
+            use_prune=use_prune,
+            return_merged=return_merged,
+        )
+        core.update_autotune_status()
+        return res

     def _run_impl(
         self,
......
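In the executor change above, the removed `try`/`except` only re-raised the caught exception, so without `six` it would have been a no-op wrapper; dropping it entirely leaves behavior unchanged. A minimal sketch of why:

    # Catching an exception only to re-raise it does not change behavior.
    def run_wrapped(fn):
        try:
            return fn()
        except Exception:
            raise        # identical outcome to having no try/except at all

    def run_direct(fn):
        return fn()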
@@ -5777,10 +5777,10 @@ class Program(object):
             .. code-block:: python

-                import six
+                import paddle

                 def print_prog(prog):
-                    for name, value in sorted(six.iteritems(prog.block(0).vars)):
+                    for name, value in sorted(prog.block(0).vars.items()):
                         print(value)
                     for op in prog.block(0).ops:
                         print("op type is {}".format(op.type))
......
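The docstring fix above swaps `six.iteritems(d)` for `d.items()`, which in Python 3 already returns a lazy view rather than a list. A standalone sketch of the replacement:

    # six.iteritems(d) -> d.items(); items() is a view object in Python 3
    vars_by_name = {"fc_0.w_0": 1, "fc_0.b_0": 2}   # stand-in for prog.block(0).vars
    for name, value in sorted(vars_by_name.items()):
        print(name, value)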
@@ -19,7 +19,6 @@ import math
 import os
 import warnings
 import logging
-import six
 import paddle.fluid as fluid
 from paddle.fluid import core
 from paddle.fluid.core import CommContext
......
@@ -14,7 +14,6 @@
 import multiprocessing
 import os
-import six
 import sys
 import threading
@@ -523,10 +522,10 @@ def _py_reader(
                     if reader.exited:
                         break
                 feed_queue.close()
-            except Exception as ex:
+            except Exception as e:
                 feed_queue.kill()
                 logging.warn('Your decorated reader has raised an exception!')
-                six.reraise(*sys.exc_info())
+                raise e

         reader.thread = threading.Thread(
             target=__provider_thread__, args=(_current_expected_place(),)
......
@@ -14998,7 +14998,6 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
             # example 1:
             import paddle
-            import six
             import numpy as np

             paddle.enable_static()
@@ -15024,7 +15023,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
             def simple_net(img, label):
                 hidden = img
-                for idx in six.moves.range(4):
+                for idx in range(4):
                     hidden = paddle.static.nn.fc(hidden, size=200)
                     new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
                         dtype=hidden.dtype, shape=hidden.shape)
@@ -15042,13 +15041,13 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
                 return ce_loss(prediction, label)

             x = paddle.static.data(name='x', shape=[1,4], dtype='float32')
-            y = paddle.static.data(name='y', shape=[1,10], dtype='int64')
+            y = paddle.static.data(name='y', shape=[1], dtype='int64')
             res = simple_net(x, y)

             exe = paddle.static.Executor(paddle.CPUPlace())
             exe.run(paddle.static.default_startup_program())

             input1 = np.random.random(size=[1,4]).astype('float32')
-            input2 = np.random.randint(1, 10, size=[1,10], dtype='int64')
+            input2 = np.random.randint(1, 10, size=[1], dtype='int64')
             out = exe.run(paddle.static.default_main_program(),
                           feed={'x':input1, 'y':input2},
                           fetch_list=[res.name])
......
@@ -14,7 +14,6 @@
 import collections
 import copy
-import six
 import numpy as np
 from ..framework import Block, Variable, _non_static_mode
 from ..data_feeder import (
......
@@ -14,7 +14,6 @@
 from . import core
 import sys
-import six
 import numpy as np
 import threading
 import paddle
@@ -143,7 +142,7 @@ def _reader_process_loop(batch_reader, data_queue):
         # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
         pass
     except:
-        six.reraise(*sys.exc_info())
+        raise

 class DataLoaderBase(object):
@@ -1202,7 +1201,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
             return self._reader.read_next_var_list()
         except StopIteration:
             self._reset()
-            six.reraise(*sys.exc_info())
+            raise

     def _exit_thread_expectedly(self):
         self._thread_done_event.set()
@@ -1232,7 +1231,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
                 # start trying to get data from queue. At this time, the child thread needs
                 # to wait slightly longer
                 tensor_list = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT)
-            except:
+            except Exception as e:
                 # NOTE [ avoid handing ] After adding the shared memory mechanism, not only
                 # the queue.Empty exception will occur here, but other exceptions will also
                 # occur, such as mmap failure. If it is not handled here, it will hang.
@@ -1240,7 +1239,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
                 logging.error(
                     "DataLoader reader thread failed to read data from the multiprocessing.Queue."
                 )
-                six.reraise(*sys.exc_info())
+                raise e

             if not self._thread_done_event.is_set():
                 if tensor_list is not None:
@@ -1250,9 +1249,9 @@ class DygraphGeneratorLoader(DataLoaderBase):
                             array.append(tensor)
                         if not self._blocking_queue.push(array):
                             self._blocking_queue.close()
-                except:
+                except Exception as e:
                     self._exit_thread_unexpectedly()
-                    six.reraise(*sys.exc_info())
+                    raise e
                 else:
                     self._exit_thread_expectedly()
@@ -1278,13 +1277,13 @@ class DygraphGeneratorLoader(DataLoaderBase):
                 self._blocking_queue.close()
                 self._thread = None
-            except Exception:
+            except Exception as e:
                 self._blocking_queue.kill()
                 self._thread = None
                 logging.warning(
                     "DygraphDataLoader reader thread raised an exception."
                 )
-                six.reraise(*sys.exc_info())
+                raise e

     def set_sample_generator(
         self, reader, batch_size, drop_last=True, places=None
@@ -1510,7 +1509,7 @@ class GeneratorLoader(DataLoaderBase):
         except StopIteration:
             self._queue.close()
             self._reset()
-            six.reraise(*sys.exc_info())
+            raise

     def start(self):
         assert (
@@ -1551,11 +1550,11 @@ class GeneratorLoader(DataLoaderBase):
                 self._queue.close()
                 self._thread = None
-            except Exception as ex:
+            except Exception as e:
                 self._queue.kill()
                 self._thread = None
                 logging.warning('Your reader has raised an exception!')
-                six.reraise(*sys.exc_info())
+                raise e

         self._thread = threading.Thread(
             target=__thread_main__, args=(_current_expected_place(),)
......
@@ -26,7 +26,6 @@ from paddle.fluid.dygraph.base import to_variable
 from test_imperative_base import new_program_scope
 from paddle.fluid.executor import global_scope
 import numpy as np
-import six
 import pickle
 import os
 import errno
......
@@ -27,6 +27,16 @@ class MockLease:
         pass


+class MockKVMetadata:
+    def __init__(self, key):
+        self.key = key
+        self.create_revision = 2
+        self.mod_revision = 3
+        self.version = 2
+        self.lease_id = 0
+        self.response_header = None
+
+
 class MockEtcdClient:
     def __init__(self, lease=None):
         self._lease = lease
@@ -35,28 +45,30 @@ class MockEtcdClient:
         pass

     def get(self, key):
-        value = "0"
-        return value, value
+        return b'0', MockKVMetadata(b"/prefix")

     def delete_prefix(self, key):
         pass

     def get_prefix(self, key_prefix):
-        hosts = ["10.10.10.1:6001", "10.10.10.2:6001"]
-        return hosts
+        hosts = [
+            (b"/prefix/host1", b"10.10.10.1:6001"),
+            (b"/prefix/host2", b"10.10.10.2:6001"),
+        ]
+        return ((v, MockKVMetadata(k)) for k, v in hosts)

     def add_watch_callback(self, *args, **kwargs):
-        return "host_watch"
+        return 0

     def add_watch_prefix_callback(self, key_prefix, callback, **kwargs):
         callback(None)
-        return "host_watch"
+        return 0

     def cancel_watch(self, watch_id):
         pass

     def delete(self, key):
-        pass
+        return True

     def lease(self, ttl):
         if self._lease:
......
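With `MockKVMetadata` in place, the mock now mirrors the shapes returned by `python-etcd3`: `get()` returns a `(bytes, metadata)` tuple and `get_prefix()` yields `(bytes, metadata)` pairs, so the `.decode()` calls under test run against realistic data. A quick usage sketch of the mock as defined above:

    client = MockEtcdClient()

    value, meta = client.get("/prefix/status")
    assert value == b'0' and meta.key == b"/prefix"

    hosts = [v.decode() for v, _meta in client.get_prefix("/prefix")]
    assert hosts == ["10.10.10.1:6001", "10.10.10.2:6001"]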
@@ -14,7 +14,6 @@
 from threading import Thread
 import multiprocessing
-import six
 import sys
 import warnings
 import logging
@@ -610,9 +609,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
                     raise ValueError("sample has None")
                 queue.put(sample)
             queue.put(None)
-        except:
+        except Exception as e:
             queue.put("")
-            six.reraise(*sys.exc_info())
+            raise e

     def queue_reader():
         queue = fork_context.Queue(queue_size)
@@ -627,11 +626,11 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
         while finish_num < reader_num:
             try:
                 sample = queue.get(timeout=QUEUE_GET_TIMEOUT)
-            except:
+            except Exception as e:
                 logging.error(
                     "multiprocess_reader failed to get data from the multiprocessing.Queue."
                 )
-                six.reraise(*sys.exc_info())
+                raise e

             if sample is None:
                 finish_num += 1
@@ -650,10 +649,10 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
                 conn.send(json.dumps(sample))
             conn.send(json.dumps(None))
             conn.close()
-        except:
+        except Exception as e:
             conn.send(json.dumps(""))
             conn.close()
-            six.reraise(*sys.exc_info())
+            raise e

     def pipe_reader():
         conns = []
......
@@ -2,7 +2,6 @@ requests>=2.20.0
 numpy>=1.13
 protobuf>=3.1.0, <=3.20.0
 Pillow
-six
 decorator
 astor
 paddle_bfloat==0.1.7
......