提交 f17f924e 编写于 作者: B barrierye

fix bug in dag executor

上级 c045eebb
......@@ -15,6 +15,7 @@
import threading
import multiprocessing
import sys
import copy
if sys.version_info.major == 2:
import Queue
elif sys.version_info.major == 3:
......@@ -58,8 +59,8 @@ class DAGExecutor(object):
self._profiler = TimeProfiler()
self._profiler.enable(use_profile)
self._dag = DAG(response_op, self._profiler, use_profile,
use_multithread, client_type, channel_size)
self._dag = DAG(response_op, use_profile, use_multithread, client_type,
channel_size)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
......@@ -75,8 +76,9 @@ class DAGExecutor(object):
self._id_lock = threading.Lock()
self._id_counter = 0
self._reset_max_id = 1000000000000000000
self._cv = threading.Condition()
self._fetch_buffer = {}
self._cv_pool = {}
self._cv_for_cv_pool = threading.Condition()
self._fetch_buffer = None
self._is_run = False
self._recive_func = None
......@@ -115,6 +117,7 @@ class DAGExecutor(object):
self._out_channel = out_channel
def _recive_out_channel_func(self):
cv = None
while self._is_run:
channeldata_dict = self._out_channel.front(self.name)
if len(channeldata_dict) != 1:
......@@ -125,18 +128,25 @@ class DAGExecutor(object):
raise TypeError(
self._log('data must be ChannelData type, but get {}'.
format(type(channeldata))))
with self._cv:
data_id = channeldata.id
self._fetch_buffer[data_id] = channeldata
self._cv.notify_all()
data_id = channeldata.id
_LOGGER.debug("recive thread fetch data: {}".format(data_id))
with self._cv_for_cv_pool:
cv = self._cv_pool[data_id]
with cv:
self._fetch_buffer = channeldata
cv.notify_all()
def _get_channeldata_from_fetch_buffer(self, data_id):
resp = None
with self._cv:
while data_id not in self._fetch_buffer:
self._cv.wait()
resp = self._fetch_buffer.pop(data_id)
self._cv.notify_all()
cv = threading.Condition()
with self._cv_for_cv_pool:
self._cv_pool[data_id] = cv
with cv:
cv.wait()
_LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
resp = copy.deepcopy(self._fetch_buffer)
# cv.notify_all()
return resp
def _pack_channeldata(self, rpc_request, data_id):
......@@ -204,8 +214,8 @@ class DAGExecutor(object):
class DAG(object):
def __init__(self, response_op, profiler, use_profile, use_multithread,
client_type, channel_size):
def __init__(self, response_op, use_profile, use_multithread, client_type,
channel_size):
self._response_op = response_op
self._use_profile = use_profile
self._use_multithread = use_multithread
......@@ -213,7 +223,6 @@ class DAG(object):
self._client_type = client_type
if not self._use_multithread:
self._manager = multiprocessing.Manager()
self._profiler = profiler
def get_use_ops(self, response_op):
unique_names = set()
......@@ -413,7 +422,7 @@ class DAG(object):
def start(self):
self._threads_or_proces = []
for op in self._actual_ops:
op.init_profiler(self._profiler, self._use_profile)
op.use_profiler(self._use_profile)
if self._use_multithread:
self._threads_or_proces.extend(
op.start_with_thread(self._client_type))
......
......@@ -20,6 +20,7 @@ from concurrent import futures
import logging
import func_timeout
import os
import sys
from numpy import *
from .proto import pipeline_service_pb2
......@@ -59,14 +60,14 @@ class Op(object):
self._retry = max(1, retry)
self._input = None
self._outputs = []
self._profiler = None
self._use_profile = False
# only for multithread
self._for_init_op_lock = threading.Lock()
self._succ_init_op = False
def init_profiler(self, profiler, use_profile):
self._profiler = profiler
def use_profiler(self, use_profile):
self._use_profile = use_profile
def _profiler_record(self, string):
......@@ -351,9 +352,8 @@ class Op(object):
os._exit(-1)
# init profiler
if not use_multithread:
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
self._is_run = True
while self._is_run:
......@@ -400,11 +400,17 @@ class Op(object):
output_channels)
continue
if self._use_profile:
profile_str = self._profiler.gen_profile_str()
sys.stderr.write(profile_str)
#TODO
#output_data.add_profile(profile_str)
# push data to channel (if run succ)
#self._profiler_record("push#{}_0".format(op_info_prefix))
self._push_to_output_channels(output_data, output_channels)
#self._profiler_record("push#{}_1".format(op_info_prefix))
self._profiler.print_profile()
#self._profiler.print_profile()
# Return `info` prefixed with this object's `name` and a single space,
# formatted as "<name> <info>".
def _log(self, info):
return "{} {}".format(self.name, info)
......
......@@ -50,6 +50,9 @@ class TimeProfiler(object):
self._time_record.put((name, tag, timestamp))
def print_profile(self):
    """Write the generated profile string to stderr.

    `gen_profile_str()` returns None (bare `return`) when profiling is
    disabled, and `sys.stderr.write(None)` raises TypeError, so nothing
    is written in that case.
    """
    profile_str = self.gen_profile_str()
    if profile_str is None:
        # Profiling disabled: there is nothing to report.
        return
    sys.stderr.write(profile_str)
def gen_profile_str(self):
if self._enable is False:
return
print_str = self._print_head
......@@ -64,7 +67,7 @@ class TimeProfiler(object):
else:
tmp[name] = (tag, timestamp)
print_str = "\n{}\n".format(print_str)
sys.stderr.write(print_str)
for name, item in tmp.items():
tag, timestamp = item
self._time_record.put((name, tag, timestamp))
return print_str
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册